P5405 [CTS2019] 氪金手游

$n$ 个点的树形图,$u$ 的点权记为 $w_u$,$w_u$ 分别有 $p_{u,1},p_{u,2},p_{u,3}$ 的概率取到 $1,2,3$。

涅普缇努会随机选择节点多次,直到节点都被选择过,一次选择中选到节点 $u$ 的概率为 $\frac{w_u}{\sum w_i}$。

记 $t_u$ 为节点 $u$ 第一次被选择的时间,树边 $u \to v$,会限制 $t_u < t_v$,涅普缇努想知道所有树边的限制都被满足的概率。

$1 \le n \le 10^3$。

先考虑叶向树怎么做。此时的限制是 $u$ 子树中的点(不包括 $u$)都比 $u$ 晚出现。考虑对于 $u$ 的限制,子树外的点都不用考虑,删去它们,那么满足限制的概率为 $\frac{w_u}{\sum_{i \in T(u)}w_i}$。该式只与 $u$ 子树中的点有关,于是可以树形 DP,记录子树中的点权即可,类似树形背包,时间复杂度 $O(n^2)$。

反向边是不好做的,而不考虑反向边(相当于删去)和反向边一定不成立(相当于正向边)是好做的,可以容斥,记 $f_i$ 为至少 $i$ 条反向边不成立的概率,答案为恰好 $0$ 条反向边不成立,等于 $\sum_{i=0}^{n-1} (-1)^i f_i$。进行带容斥系数的 DP 即可,时间复杂度 $O(n^2)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include<bits/stdc++.h>
using namespace std;

#define pb push_back
#define ar(x) array<int,x>
const int N=1e3+10,INF=1e9,MOD=998244353;
inline void mod(int &x){ if(x>=MOD) x-=MOD; if(x<0) x+=MOD; }
inline int qm(int x,int y=MOD-2){ int res=1; for(;y;y>>=1,x=1ll*x*x%MOD) if(y&1) res=1ll*res*x%MOD; return res; }
#define gc getchar()
#define rd read()
inline int read(){
int x=0,f=0; char c=gc;
for(;c<'0'||c>'9';c=gc) f|=(c=='-');
for(;c>='0'&&c<='9';c=gc) x=(x<<1)+(x<<3)+(c^48);
return f?-x:x;
}

int n,f[N][N*3],p[N][4],tmp[N*3],siz[N];
vector<ar(2)> G[N];

void dfs(int u,int fu){
f[u][siz[u]=0]=1;
for(auto [v,w]:G[u]){
if(v==fu) continue; dfs(v,u);
if(w==0){
for(int i=0;i<=3*siz[u];++i) tmp[i]=f[u][i],f[u][i]=0;
for(int i=0;i<=3*siz[u];++i)
for(int j=0;j<=3*siz[v];++j)
mod(f[u][i+j]+=1ll*tmp[i]*f[v][j]%MOD);
}
else{
int all=0; for(int i=0;i<=3*siz[v];++i) mod(all+=f[v][i]);
for(int i=0;i<=3*siz[u];++i) tmp[i]=f[u][i],f[u][i]=1ll*f[u][i]*all%MOD;
for(int i=0;i<=3*siz[u];++i)
for(int j=0;j<=3*siz[v];++j)
mod(f[u][i+j]-=1ll*tmp[i]*f[v][j]%MOD);
}
siz[u]+=siz[v];
}
for(int i=0;i<=3*siz[u];++i) tmp[i]=f[u][i],f[u][i]=0;
for(int i=0;i<=3*siz[u];++i)
for(int j=1;j<=3;++j)
mod(f[u][i+j]+=1ll*tmp[i]*p[u][j]%MOD*j%MOD*qm(i+j)%MOD);
++siz[u];
}

int main(){

n=rd;
for(int i=1,x,y,z,inv;i<=n;++i){
x=rd,y=rd,z=rd,inv=qm(x+y+z);
p[i][1]=1ll*x*inv%MOD;
p[i][2]=1ll*y*inv%MOD;
p[i][3]=1ll*z*inv%MOD;
}
for(int i=1,x,y;i<=n-1;++i) x=rd,y=rd,G[x].pb({y,0}),G[y].pb({x,1});

dfs(1,0);

int ans=0; for(int i=0;i<=3*n;++i) mod(ans+=f[1][i]);
printf("%d\n", ans);

return 0;
}