diff --git a/src/main/java/com/rymcu/forest/util/SSRFUtil.java b/src/main/java/com/rymcu/forest/util/SSRFUtil.java new file mode 100644 index 0000000..bd9d921 --- /dev/null +++ b/src/main/java/com/rymcu/forest/util/SSRFUtil.java @@ -0,0 +1,157 @@ +package com.rymcu.forest.util; + +import com.google.common.net.InternetDomainName; + +import java.net.MalformedURLException; +import java.net.URL; +import java.util.Objects; + +/** + * Created on 2023/12/29 11:52. + * + * @author ronger + * @email ronger-x@outlook.com + * @desc : com.rymcu.forest.util + */ +public class SSRFUtil { + public static boolean checkUrl(URL url, boolean checkWhiteList) { + // 协议限制 + if (!url.getProtocol().startsWith("http") && !url.getProtocol().startsWith("https")) { + return false; + } + try { + // 获取域名,并转为小写 + String host = url.getHost().toLowerCase(); + // 禁止内网 IP + if (!internalIp(host)) { + return false; + } + if (checkWhiteList) { + // 获取一级域名 + String rootDomain = InternetDomainName.from(host).topPrivateDomain().toString(); + // TODO 白名单 + } + } catch (IllegalArgumentException exception) { + return false; + } + return true; + } + + public static void main(String[] args) throws MalformedURLException { + URL url = new URL("http://192.168.0.1"); + boolean b = checkUrl(url, false); + System.out.println(b); + } + + public static boolean internalIp(String ip) { + byte[] addr = textToNumericFormatV4(ip); + return internalIp(addr) || "127.0.0.1".equals(ip); + } + + private static boolean internalIp(byte[] addr) { + if (Objects.isNull(addr) || addr.length < 2) { + return true; + } + final byte b0 = addr[0]; + final byte b1 = addr[1]; + // 10.x.x.x/8 + final byte SECTION_1 = 0x0A; + // 172.16.x.x/12 + final byte SECTION_2 = (byte) 0xAC; + final byte SECTION_3 = (byte) 0x10; + final byte SECTION_4 = (byte) 0x1F; + // 192.168.x.x/16 + final byte SECTION_5 = (byte) 0xC0; + final byte SECTION_6 = (byte) 0xA8; + switch (b0) { + case SECTION_1: + return true; + case SECTION_2: + if (b1 >= SECTION_3 && b1 <= SECTION_4) { + return true; + } + case SECTION_5: + switch (b1) { + case SECTION_6: + return true; + } + default: + return false; + } + } + + /** + * 将IPv4地址转换成字节 + * + * @param text IPv4地址 + * @return byte 字节 + */ + public static byte[] textToNumericFormatV4(String text) { + if (text.isEmpty()) { + return null; + } + + byte[] bytes = new byte[4]; + String[] elements = text.split("\\.", -1); + try { + long l; + int i; + switch (elements.length) { + case 1: + l = Long.parseLong(elements[0]); + if ((l < 0L) || (l > 4294967295L)) { + return null; + } + bytes[0] = (byte) (int) (l >> 24 & 0xFF); + bytes[1] = (byte) (int) ((l & 0xFFFFFF) >> 16 & 0xFF); + bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF); + bytes[3] = (byte) (int) (l & 0xFF); + break; + case 2: + l = Integer.parseInt(elements[0]); + if ((l < 0L) || (l > 255L)) { + return null; + } + bytes[0] = (byte) (int) (l & 0xFF); + l = Integer.parseInt(elements[1]); + if ((l < 0L) || (l > 16777215L)) { + return null; + } + bytes[1] = (byte) (int) (l >> 16 & 0xFF); + bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF); + bytes[3] = (byte) (int) (l & 0xFF); + break; + case 3: + for (i = 0; i < 2; ++i) { + l = Integer.parseInt(elements[i]); + if ((l < 0L) || (l > 255L)) { + return null; + } + bytes[i] = (byte) (int) (l & 0xFF); + } + l = Integer.parseInt(elements[2]); + if ((l < 0L) || (l > 65535L)) { + return null; + } + bytes[2] = (byte) (int) (l >> 8 & 0xFF); + bytes[3] = (byte) (int) (l & 0xFF); + break; + case 4: + for (i = 0; i < 4; ++i) { + l = Integer.parseInt(elements[i]); + if ((l < 0L) || (l > 255L)) { + return null; + } + bytes[i] = (byte) (int) (l & 0xFF); + } + break; + default: + return null; + } + } catch (NumberFormatException e) { + return null; + } + return bytes; + } + +} diff --git a/src/main/java/com/rymcu/forest/web/api/common/UploadController.java b/src/main/java/com/rymcu/forest/web/api/common/UploadController.java index aa5389c..376f655 100644 --- a/src/main/java/com/rymcu/forest/web/api/common/UploadController.java +++ b/src/main/java/com/rymcu/forest/web/api/common/UploadController.java @@ -25,13 +25,11 @@ import org.springframework.web.multipart.MultipartFile; import javax.annotation.Resource; import javax.servlet.http.HttpServletRequest; -import java.io.ByteArrayOutputStream; -import java.io.File; -import java.io.IOException; -import java.io.InputStream; +import java.io.*; import java.net.HttpURLConnection; import java.net.URL; import java.util.*; +import com.rymcu.forest.util.SSRFUtil; /** * 文件上传控制器 @@ -265,6 +263,10 @@ public class UploadController { return GlobalResultGenerator.genSuccessResult(data); } URL link = new URL(url); + // SSRF 校验 + if (!SSRFUtil.checkUrl(link, false)) { + throw new FileNotFoundException(); + } HttpURLConnection conn = (HttpURLConnection) link.openConnection(); //设置超时间为3秒 conn.setConnectTimeout(3 * 1000);