算法学习-最小生成树模板(JAVA实现)

本文最后更新于:April 6, 2022 pm

纸上得来终觉浅,绝知此事要躬行。路漫漫其修远兮,吾将上下而求索!知识是经过历史的巨人沉淀下来的,别总想着自己能够快速学会,多花点时间去看看,也许会发现些不同的东西。你能快速学会的、觉得简单的东西,对于别人来说也是一样的。人外有人,天外有天。当努力到达了一定的程度,幸运自会与你不期而遇。

目录

提示:n表示点数,m表示边数。

练习题目 (代码以此题为例)

练习题目

练习题目 此题只能用Scanner输入,因为需要判断是否有下一次输入。(至少我只知道这一种)

Prim

朴素版Prim

适用于稠密图。时间复杂度:$O(n^2)$

核心思想:每次挑一条与当前集合相连的最短边。

最短边的定义为:取当前点和集合中所有点比较后的最短距离。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import java.io.*;
import java.math.BigInteger;
import java.util.*;


/**
 * Naive (linear-scan) Prim minimum-spanning-tree template.
 *
 * <p>The graph is stored as a chained forward-star adjacency list: {@code head[v]} holds the
 * index of the most recently added edge leaving {@code v}, and {@code edgePre[e]} links to the
 * previously added edge of the same vertex ({@code -1} terminates the chain).
 *
 * <p>Input read by {@link #main}: {@code n m} followed by {@code m} triples {@code a b c}.
 * Output: the total weight of the spanning tree, or {@code orz} when {@link #Prim()} reports
 * that no spanning tree exists.
 *
 * Author: DragonOne
 * Date: 2021/12/5 21:27
 * Blog: www.tothefor.com
 */
public class Main {
    public static BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
    public static BufferedWriter out = new BufferedWriter(new OutputStreamWriter(System.out));
    public static StreamTokenizer cin = new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
    public static PrintWriter cout = new PrintWriter(new OutputStreamWriter(System.out));
    public static Scanner sc = new Scanner(System.in);

    /** Capacity of the per-vertex arrays; the edge arrays hold twice as many entries. */
    public static int maxd = 200000 + 7;
    /** Sentinel for "not reachable yet"; also returned by {@link #Prim()} when no tree exists. */
    public static int INF = 0x3f3f3f3f;
    public static int mod = (int) 1e9 + 7;
    /** dis[v]: cheapest known edge weight connecting v to the tree built so far. */
    public static int[] dis = new int[maxd];
    /** head[v]: index of the last edge added for v, or -1 when v has no edges. */
    public static int[] head = new int[maxd];
    /** edgePre[e]: previous edge in the chain of the same source vertex (undirected, so 2x size). */
    public static int[] edgePre = new int[maxd << 1];
    public static int[] edgeW = new int[maxd << 1];
    public static int[] edgeTo = new int[maxd << 1];
    /** vis[v]: true once v has been moved into the tree. */
    public static boolean[] vis = new boolean[maxd];
    /** Number of vertices; vertices are numbered 1..n. */
    public static int n;
    /** Next free slot in the edge arrays. */
    public static int node = 0;

    /**
     * Appends the directed edge {@code a -> b} with weight {@code c} to the adjacency list.
     * Call it twice (both directions) for an undirected edge.
     */
    public static void add_edge(int a, int b, int c) {
        edgeTo[node] = b;
        edgeW[node] = c;
        edgePre[node] = head[a];
        head[a] = node++;
    }

    /**
     * Resets the per-vertex state for vertices {@code 0..n} and empties the edge store.
     *
     * <p>The edge counter is reset as well, so the method can be called once per test case
     * without the edge arrays filling up across cases.
     */
    public static void init(int n) {
        for (int i = 0; i <= n; ++i) {
            dis[i] = INF;
            vis[i] = false;
            head[i] = -1;
        }
        node = 0;
    }

    /**
     * Runs naive Prim over vertices {@code 1..n}; the first vertex picked seeds the tree.
     *
     * @return the sum of the chosen edge weights, or {@link #INF} when some vertex cannot be
     *         connected to the tree
     */
    public static int Prim() {
        int res = 0;
        for (int i = 1; i <= n; ++i) {
            // Pick the unvisited vertex with the smallest connection cost.
            int t = -1;
            for (int j = 1; j <= n; ++j) {
                if (!vis[j] && (t == -1 || dis[t] > dis[j])) {
                    t = j;
                }
            }
            vis[t] = true;
            if (i != 1) { // the seed vertex contributes no edge
                if (dis[t] == INF) {
                    return INF; // t is not connected to anything in the tree
                }
                res += dis[t];
            }
            // Relax the connection cost of t's neighbours that are still outside the tree.
            for (int e = head[t]; e != -1; e = edgePre[e]) {
                int to = edgeTo[e];
                if (!vis[to] && edgeW[e] < dis[to]) {
                    dis[to] = edgeW[e];
                }
            }
        }
        return res;
    }

    public static void main(String[] args) throws Exception {
        n = nextInt();
        int m = nextInt();
        init(n);
        while (m-- > 0) {
            int a = nextInt();
            int b = nextInt();
            int c = nextInt();
            add_edge(a, b, c); // undirected graph: store both directions
            add_edge(b, a, c);
        }
        int ans = Prim();
        cout.println(ans == INF ? "orz" : String.valueOf(ans));
        cout.flush();
        closeAll();
    }

    /**
     * Configures {@link #cin} so letters form words, '/' starts a comment, quotes delimit
     * strings and numbers are parsed. Not called by {@link #main}.
     */
    public static void cinInit() {
        cin.wordChars('a', 'z');
        cin.wordChars('A', 'Z');
        cin.wordChars(128 + 32, 255);
        cin.whitespaceChars(0, ' ');
        cin.commentChar('/');
        cin.quoteChar('"');
        cin.quoteChar('\'');
        cin.parseNumbers(); // can be called on its own to restore number parsing
    }

    /** Advances {@link #cin} by one token and returns its numeric value truncated to int. */
    public static int nextInt() throws Exception {
        cin.nextToken();
        return (int) cin.nval;
    }

    /** Advances {@link #cin} by one token and returns its numeric value truncated to long. */
    public static long nextLong() throws Exception {
        cin.nextToken();
        return (long) cin.nval;
    }

    /** Advances {@link #cin} by one token and returns its numeric value. */
    public static double nextDouble() throws Exception {
        cin.nextToken();
        return cin.nval;
    }

    /** Advances {@link #cin} by one token and returns its string value (null for non-word tokens). */
    public static String nextString() throws Exception {
        cin.nextToken();
        return cin.sval;
    }

    /** Closes the shared output and input streams. */
    public static void closeAll() throws Exception {
        cout.close();
        in.close();
        out.close();
    }

}

堆优化Prim

时间复杂度 O(m*logn)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import java.io.*;
import java.math.BigInteger;
import java.sql.Time;
import java.text.SimpleDateFormat;
import java.util.*;

/**
 * Heap-optimised Prim minimum-spanning-tree template.
 *
 * <p>The graph is stored as a chained forward-star adjacency list: {@code head[v]} holds the
 * index of the most recently added edge leaving {@code v}, and {@code edgePre[e]} links to the
 * previously added edge of the same vertex ({@code -1} terminates the chain).
 *
 * <p>Input read by {@link #main}: {@code n m} followed by {@code m} triples {@code a b c}.
 * Output: the total weight of the spanning tree, or {@code orz} when the graph is not connected.
 *
 * Author: DragonOne
 * Date: 2021/12/5 21:27
 * Blog: www.tothefor.com
 */
public class Main {
    public static BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
    public static BufferedWriter out = new BufferedWriter(new OutputStreamWriter(System.out));
    public static StreamTokenizer cin = new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
    public static PrintWriter cout = new PrintWriter(new OutputStreamWriter(System.out));
    public static Scanner sc = new Scanner(System.in);
    /**
     * Days in each month, indexed 1..12 (index 0 is a placeholder; February is 28).
     */
    public static int monthes[] = {0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31};
    /** Capacity of the per-vertex arrays; the edge arrays hold twice as many entries. */
    public static int maxd = 200000 + 15;
    /** Sentinel for "not reachable yet"; also returned by {@link #Prim} when no tree exists. */
    public static int INF = 0x3f3f3f3f;
    public static int mod = (int) 1e9 + 7;
    /** The four axis-aligned unit steps. */
    public static int[][] fx = {{1, 0}, {0, 1}, {-1, 0}, {0, -1}};

    /** Next free slot in the edge arrays. */
    public static int node = 0;
    public static int[] head = new int[maxd];
    public static int[] edgePre = new int[maxd * 2];
    public static int[] edgeW = new int[maxd * 2];
    public static int[] edgeTo = new int[maxd * 2];
    /** dis[v]: cheapest known edge weight connecting v to the tree built so far. */
    public static int[] dis = new int[maxd];
    public static boolean[] vis = new boolean[maxd];

    /** Priority-queue entry: a vertex together with its connection cost to the tree. */
    static class Edge implements Comparable<Edge> {
        private final int point; // the vertex
        private final int w;     // cheapest weight connecting the vertex to the tree

        Edge(int point, int w) {
            this.point = point;
            this.w = w;
        }

        /** Orders entries by ascending weight; uses Integer.compare so large gaps cannot overflow. */
        @Override
        public int compareTo(Edge obj) {
            return Integer.compare(this.w, obj.w);
        }
    }

    /** Resets the per-vertex state for vertices {@code 1..n} and empties the edge store. */
    public static void init(int n) {
        for (int i = 1; i <= n; ++i) {
            head[i] = -1;
            dis[i] = INF;
            vis[i] = false;
        }
        node = 0;
    }

    /**
     * Appends the directed edge {@code a -> b} with weight {@code c} to the adjacency list.
     * Call it twice (both directions) for an undirected edge.
     */
    public static void add_edge(int a, int b, int c) {
        edgeTo[node] = b;
        edgeW[node] = c;
        edgePre[node] = head[a];
        head[a] = node++;
    }

    /**
     * Runs heap-based Prim from {@code start}.
     *
     * @param start the vertex that seeds the tree
     * @param n     the number of vertices the finished tree must contain
     * @return the sum of the chosen edge weights, or {@link #INF} when fewer than {@code n}
     *         vertices are reachable from {@code start}
     */
    public static int Prim(int start, int n) {
        PriorityQueue<Edge> q = new PriorityQueue<>();
        int ans = 0;
        int cnt = 0; // vertices absorbed so far; the tree exists only if this reaches n
        dis[start] = 0;
        q.offer(new Edge(start, 0));
        while (!q.isEmpty()) {
            Edge now = q.poll();
            int u = now.point;
            if (vis[u]) {
                continue; // stale entry: u was already absorbed via a cheaper edge
            }
            vis[u] = true;
            ans += now.w;
            cnt++;
            for (int e = head[u]; e != -1; e = edgePre[e]) {
                int to = edgeTo[e];
                // Only vertices outside the tree can still use a cheaper connection.
                if (!vis[to] && dis[to] > edgeW[e]) {
                    dis[to] = edgeW[e];
                    q.offer(new Edge(to, edgeW[e]));
                }
            }
        }
        return cnt != n ? INF : ans;
    }

    public static void main(String[] args) throws Exception {
        int n = nextInt();
        int m = nextInt();
        init(n);
        for (int i = 1; i <= m; ++i) {
            int a = nextInt();
            int b = nextInt();
            int c = nextInt();
            add_edge(a, b, c); // undirected graph: store both directions
            add_edge(b, a, c);
        }
        int prim = Prim(1, n);
        System.out.println(prim == INF ? "orz" : String.valueOf(prim));

        closeAll();
    }

    /**
     * Euclidean gcd; the order of the arguments does not matter.
     * The loop only runs while {@code b > 0}.
     */
    public static int gcd(int a, int b) {
        while (b > 0) {
            int r = a % b;
            a = b;
            b = r;
        }
        return a;
    }

    /**
     * Base-2 logarithm of {@code x}, truncated to int (Java's own {@code Math.log} is base e).
     * Positive inputs are computed from the bit length, which is exact; non-positive inputs
     * keep the result of the floating-point formula.
     */
    public static int log(int x) {
        if (x > 0) {
            return 31 - Integer.numberOfLeadingZeros(x);
        }
        return (int) (Math.log(x) / Math.log(2));
    }

    /**
     * Configures {@link #cin} so letters form words, '/' starts a comment, quotes delimit
     * strings and numbers are parsed. Not called by {@link #main}.
     */
    public static void cinInit() {
        cin.wordChars('a', 'z');
        cin.wordChars('A', 'Z');
        cin.wordChars(128 + 32, 255);
        cin.whitespaceChars(0, ' ');
        cin.commentChar('/');
        cin.quoteChar('"');
        cin.quoteChar('\'');
        cin.parseNumbers(); // can be called on its own to restore number parsing
    }

    /** Advances {@link #cin} by one token and returns its numeric value truncated to int. */
    public static int nextInt() throws Exception {
        cin.nextToken();
        return (int) cin.nval;
    }

    /** Advances {@link #cin} by one token and returns its numeric value truncated to long. */
    public static long nextLong() throws Exception {
        cin.nextToken();
        return (long) cin.nval;
    }

    /** Advances {@link #cin} by one token and returns its numeric value. */
    public static double nextDouble() throws Exception {
        cin.nextToken();
        return cin.nval;
    }

    /** Advances {@link #cin} by one token and returns its string value (null for non-word tokens). */
    public static String nextString() throws Exception {
        cin.nextToken();
        return cin.sval;
    }

    /** Closes the shared output and input streams. */
    public static void closeAll() throws Exception {
        cout.close();
        in.close();
        out.close();
    }
}

Kruskal

适用于稀疏图,时间复杂度 O(m*logm)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import java.io.*;
import java.math.BigInteger;
import java.util.*;


/**
 * Kruskal minimum-spanning-tree template backed by a union-find (disjoint set) structure.
 *
 * <p>Input read by {@link #main}: {@code n m} followed by {@code m} triples {@code a b c}.
 * Output: the total weight of the spanning tree, or {@code orz} when fewer than {@code n - 1}
 * edges could be selected.
 *
 * Author: DragonOne
 * Date: 2021/12/5 21:27
 * Blog: www.tothefor.com
 */
public class Main {
    public static BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
    public static BufferedWriter out = new BufferedWriter(new OutputStreamWriter(System.out));
    public static StreamTokenizer cin = new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
    public static PrintWriter cout = new PrintWriter(new OutputStreamWriter(System.out));
    public static Scanner sc = new Scanner(System.in);

    /** Capacity of the parent array; the edge array holds twice as many entries. */
    public static int maxd = 200000 + 7;
    /** Returned by {@link #Kruskal} when no spanning tree exists. */
    public static int INF = 0x3f3f3f3f;
    public static int mod = (int) 1e9 + 7;
    /** par[v]: parent of v in the union-find forest; a root is its own parent. */
    public static int[] par = new int[maxd];
    /** Number of vertices; vertices are numbered 1..n. */
    public static int n;

    /** One edge of the input graph. */
    public static class Edge implements Comparable<Edge> {
        private final int u; // first endpoint
        private final int v; // second endpoint
        private final int w; // weight

        public Edge(int u, int v, int w) {
            this.u = u;
            this.v = v;
            this.w = w;
        }

        /** Orders edges by ascending weight; uses Integer.compare so large gaps cannot overflow. */
        @Override
        public int compareTo(Edge obj) {
            return Integer.compare(this.w, obj.w);
        }
    }

    /** Edge store; only the first m slots are filled and sorted. */
    public static Edge[] edges = new Edge[maxd << 1];

    /**
     * Makes every vertex {@code 1..n} its own set.
     *
     * @param n number of vertices
     * @param m number of edges; accepted for call compatibility but not used
     */
    public static void init(int n, int m) {
        for (int i = 1; i <= n; ++i) {
            par[i] = i;
        }
    }

    /**
     * Returns the representative of the set containing {@code x} and points every vertex on
     * the walked path directly at it. Iterative, so a long parent chain cannot exhaust the
     * call stack.
     */
    public static int find(int x) {
        int root = x;
        while (root != par[root]) {
            root = par[root];
        }
        // Second pass: compress the path from x up to the root.
        while (par[x] != root) {
            int next = par[x];
            par[x] = root;
            x = next;
        }
        return root;
    }

    /** Merges the sets containing {@code a} and {@code b}; no-op when they are already one set. */
    public static void unite(int a, int b) {
        int aa = find(a);
        int bb = find(b);
        if (aa != bb) {
            par[aa] = bb;
        }
    }

    /**
     * Runs Kruskal over {@code edges[0..m)}; sorts that range in place by weight.
     *
     * @param m number of edges stored in {@link #edges}
     * @return the sum of the selected edge weights, or {@link #INF} when fewer than
     *         {@code n - 1} edges could be selected (graph not connected)
     */
    public static int Kruskal(int m) {
        int res = 0; // accumulated tree weight
        int cnt = 0; // number of edges selected so far
        Arrays.sort(edges, 0, m);
        // A spanning tree has exactly n - 1 edges, so stop as soon as that many are chosen.
        for (int i = 0; i < m && cnt < n - 1; ++i) {
            Edge e = edges[i];
            int ru = find(e.u);
            int rv = find(e.v);
            if (ru != rv) {
                par[ru] = rv; // merge the two components
                res += e.w;
                cnt++;
            }
        }
        return cnt < n - 1 ? INF : res;
    }

    public static void main(String[] args) throws Exception {
        n = nextInt();
        int m = nextInt();
        init(n, m);
        for (int i = 0; i < m; ++i) {
            int a = nextInt();
            int b = nextInt();
            int c = nextInt();
            edges[i] = new Edge(a, b, c);
        }

        int ans = Kruskal(m);
        cout.println(ans == INF ? "orz" : String.valueOf(ans));
        cout.flush();
        closeAll();
    }

    /**
     * Configures {@link #cin} so letters form words, '/' starts a comment, quotes delimit
     * strings and numbers are parsed. Not called by {@link #main}.
     */
    public static void cinInit() {
        cin.wordChars('a', 'z');
        cin.wordChars('A', 'Z');
        cin.wordChars(128 + 32, 255);
        cin.whitespaceChars(0, ' ');
        cin.commentChar('/');
        cin.quoteChar('"');
        cin.quoteChar('\'');
        cin.parseNumbers(); // can be called on its own to restore number parsing
    }

    /** Advances {@link #cin} by one token and returns its numeric value truncated to int. */
    public static int nextInt() throws Exception {
        cin.nextToken();
        return (int) cin.nval;
    }

    /** Advances {@link #cin} by one token and returns its numeric value truncated to long. */
    public static long nextLong() throws Exception {
        cin.nextToken();
        return (long) cin.nval;
    }

    /** Advances {@link #cin} by one token and returns its numeric value. */
    public static double nextDouble() throws Exception {
        cin.nextToken();
        return cin.nval;
    }

    /** Advances {@link #cin} by one token and returns its string value (null for non-word tokens). */
    public static String nextString() throws Exception {
        cin.nextToken();
        return cin.sval;
    }

    /** Closes the shared output and input streams. */
    public static void closeAll() throws Exception {
        cout.close();
        in.close();
        out.close();
    }

}