diff --git a/pom.xml b/pom.xml
index 69d90a5..7e6e0a3 100644
--- a/pom.xml
+++ b/pom.xml
@@ -5,7 +5,7 @@
org.springframework.boot
spring-boot-starter-parent
- 2.7.13
+ 2.7.17
com.rymcu
@@ -39,8 +39,26 @@
org.yaml
snakeyaml
+
+ ch.qos.logback
+ logback-classic
+
+
+ ch.qos.logback
+ logback-core
+
+
+ ch.qos.logback
+ logback-classic
+ 1.4.11
+
+
+ ch.qos.logback
+ logback-core
+ 1.4.11
+
io.netty
netty-codec
@@ -83,7 +101,7 @@
mysql
mysql-connector-java
- 8.0.30
+ 8.0.33
runtime
@@ -105,7 +123,7 @@
org.apache.tomcat.embed
tomcat-embed-core
- 9.0.83
+ 9.0.84
org.springframework.boot
@@ -371,6 +389,11 @@
javax.validation
validation-api
+
+ com.google.guava
+ guava
+ 33.0.0-jre
+
diff --git a/src/main/java/com/rymcu/forest/util/SSRFUtil.java b/src/main/java/com/rymcu/forest/util/SSRFUtil.java
new file mode 100644
index 0000000..bd9d921
--- /dev/null
+++ b/src/main/java/com/rymcu/forest/util/SSRFUtil.java
@@ -0,0 +1,157 @@
+package com.rymcu.forest.util;
+
+import com.google.common.net.InternetDomainName;
+
+import java.net.MalformedURLException;
+import java.net.URL;
+import java.util.Objects;
+
+/**
+ * Created on 2023/12/29 11:52.
+ *
+ * @author ronger
+ * @email ronger-x@outlook.com
+ * @desc : com.rymcu.forest.util
+ */
+public class SSRFUtil {
+ public static boolean checkUrl(URL url, boolean checkWhiteList) {
+ // 协议限制
+ if (!url.getProtocol().startsWith("http") && !url.getProtocol().startsWith("https")) {
+ return false;
+ }
+ try {
+ // 获取域名,并转为小写
+ String host = url.getHost().toLowerCase();
+ // 禁止内网 IP
+ if (!internalIp(host)) {
+ return false;
+ }
+ if (checkWhiteList) {
+ // 获取一级域名
+ String rootDomain = InternetDomainName.from(host).topPrivateDomain().toString();
+ // TODO 白名单
+ }
+ } catch (IllegalArgumentException exception) {
+ return false;
+ }
+ return true;
+ }
+
+ public static void main(String[] args) throws MalformedURLException {
+ URL url = new URL("http://192.168.0.1");
+ boolean b = checkUrl(url, false);
+ System.out.println(b);
+ }
+
+ public static boolean internalIp(String ip) {
+ byte[] addr = textToNumericFormatV4(ip);
+ return internalIp(addr) || "127.0.0.1".equals(ip);
+ }
+
+ private static boolean internalIp(byte[] addr) {
+ if (Objects.isNull(addr) || addr.length < 2) {
+ return true;
+ }
+ final byte b0 = addr[0];
+ final byte b1 = addr[1];
+ // 10.x.x.x/8
+ final byte SECTION_1 = 0x0A;
+ // 172.16.x.x/12
+ final byte SECTION_2 = (byte) 0xAC;
+ final byte SECTION_3 = (byte) 0x10;
+ final byte SECTION_4 = (byte) 0x1F;
+ // 192.168.x.x/16
+ final byte SECTION_5 = (byte) 0xC0;
+ final byte SECTION_6 = (byte) 0xA8;
+ switch (b0) {
+ case SECTION_1:
+ return true;
+ case SECTION_2:
+ if (b1 >= SECTION_3 && b1 <= SECTION_4) {
+ return true;
+ }
+ case SECTION_5:
+ switch (b1) {
+ case SECTION_6:
+ return true;
+ }
+ default:
+ return false;
+ }
+ }
+
+ /**
+ * 将IPv4地址转换成字节
+ *
+ * @param text IPv4地址
+ * @return byte 字节
+ */
+ public static byte[] textToNumericFormatV4(String text) {
+ if (text.isEmpty()) {
+ return null;
+ }
+
+ byte[] bytes = new byte[4];
+ String[] elements = text.split("\\.", -1);
+ try {
+ long l;
+ int i;
+ switch (elements.length) {
+ case 1:
+ l = Long.parseLong(elements[0]);
+ if ((l < 0L) || (l > 4294967295L)) {
+ return null;
+ }
+ bytes[0] = (byte) (int) (l >> 24 & 0xFF);
+ bytes[1] = (byte) (int) ((l & 0xFFFFFF) >> 16 & 0xFF);
+ bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF);
+ bytes[3] = (byte) (int) (l & 0xFF);
+ break;
+ case 2:
+ l = Integer.parseInt(elements[0]);
+ if ((l < 0L) || (l > 255L)) {
+ return null;
+ }
+ bytes[0] = (byte) (int) (l & 0xFF);
+ l = Integer.parseInt(elements[1]);
+ if ((l < 0L) || (l > 16777215L)) {
+ return null;
+ }
+ bytes[1] = (byte) (int) (l >> 16 & 0xFF);
+ bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF);
+ bytes[3] = (byte) (int) (l & 0xFF);
+ break;
+ case 3:
+ for (i = 0; i < 2; ++i) {
+ l = Integer.parseInt(elements[i]);
+ if ((l < 0L) || (l > 255L)) {
+ return null;
+ }
+ bytes[i] = (byte) (int) (l & 0xFF);
+ }
+ l = Integer.parseInt(elements[2]);
+ if ((l < 0L) || (l > 65535L)) {
+ return null;
+ }
+ bytes[2] = (byte) (int) (l >> 8 & 0xFF);
+ bytes[3] = (byte) (int) (l & 0xFF);
+ break;
+ case 4:
+ for (i = 0; i < 4; ++i) {
+ l = Integer.parseInt(elements[i]);
+ if ((l < 0L) || (l > 255L)) {
+ return null;
+ }
+ bytes[i] = (byte) (int) (l & 0xFF);
+ }
+ break;
+ default:
+ return null;
+ }
+ } catch (NumberFormatException e) {
+ return null;
+ }
+ return bytes;
+ }
+
+}
diff --git a/src/main/java/com/rymcu/forest/web/api/common/UploadController.java b/src/main/java/com/rymcu/forest/web/api/common/UploadController.java
index aa5389c..376f655 100644
--- a/src/main/java/com/rymcu/forest/web/api/common/UploadController.java
+++ b/src/main/java/com/rymcu/forest/web/api/common/UploadController.java
@@ -25,13 +25,11 @@ import org.springframework.web.multipart.MultipartFile;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
-import java.io.ByteArrayOutputStream;
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
+import java.io.*;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.*;
+import com.rymcu.forest.util.SSRFUtil;
/**
* 文件上传控制器
@@ -265,6 +263,10 @@ public class UploadController {
return GlobalResultGenerator.genSuccessResult(data);
}
URL link = new URL(url);
+ // SSRF 校验
+ if (!SSRFUtil.checkUrl(link, false)) {
+ throw new FileNotFoundException();
+ }
HttpURLConnection conn = (HttpURLConnection) link.openConnection();
//设置超时间为3秒
conn.setConnectTimeout(3 * 1000);