1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
| #include<bits/stdc++.h> using namespace std;
#define pb push_back #define ar(x) array<int,x> const int N=1e3+10,INF=1e9,MOD=998244353; inline void mod(int &x){ if(x>=MOD) x-=MOD; if(x<0) x+=MOD; } inline int qm(int x,int y=MOD-2){ int res=1; for(;y;y>>=1,x=1ll*x*x%MOD) if(y&1) res=1ll*res*x%MOD; return res; } #define gc getchar() #define rd read() inline int read(){ int x=0,f=0; char c=gc; for(;c<'0'||c>'9';c=gc) f|=(c=='-'); for(;c>='0'&&c<='9';c=gc) x=(x<<1)+(x<<3)+(c^48); return f?-x:x; }
int n,f[N][N*3],p[N][4],tmp[N*3],siz[N]; vector<ar(2)> G[N];
void dfs(int u,int fu){ f[u][siz[u]=0]=1; for(auto [v,w]:G[u]){ if(v==fu) continue; dfs(v,u); if(w==0){ for(int i=0;i<=3*siz[u];++i) tmp[i]=f[u][i],f[u][i]=0; for(int i=0;i<=3*siz[u];++i) for(int j=0;j<=3*siz[v];++j) mod(f[u][i+j]+=1ll*tmp[i]*f[v][j]%MOD); } else{ int all=0; for(int i=0;i<=3*siz[v];++i) mod(all+=f[v][i]); for(int i=0;i<=3*siz[u];++i) tmp[i]=f[u][i],f[u][i]=1ll*f[u][i]*all%MOD; for(int i=0;i<=3*siz[u];++i) for(int j=0;j<=3*siz[v];++j) mod(f[u][i+j]-=1ll*tmp[i]*f[v][j]%MOD); } siz[u]+=siz[v]; } for(int i=0;i<=3*siz[u];++i) tmp[i]=f[u][i],f[u][i]=0; for(int i=0;i<=3*siz[u];++i) for(int j=1;j<=3;++j) mod(f[u][i+j]+=1ll*tmp[i]*p[u][j]%MOD*j%MOD*qm(i+j)%MOD); ++siz[u]; }
int main(){
n=rd; for(int i=1,x,y,z,inv;i<=n;++i){ x=rd,y=rd,z=rd,inv=qm(x+y+z); p[i][1]=1ll*x*inv%MOD; p[i][2]=1ll*y*inv%MOD; p[i][3]=1ll*z*inv%MOD; } for(int i=1,x,y;i<=n-1;++i) x=rd,y=rd,G[x].pb({y,0}),G[y].pb({x,1});
dfs(1,0);
int ans=0; for(int i=0;i<=3*n;++i) mod(ans+=f[1][i]); printf("%d\n", ans);
return 0; }
|