Add @Override annotations to override methods before decomp

This commit is contained in:
Kyle Wood 2020-12-25 02:40:22 -08:00
parent 4c1a2ee227
commit e254b4e3fd
5 changed files with 242 additions and 103 deletions

View file

@ -309,6 +309,7 @@ class Paperweight : Plugin<Project> {
val fixJar by tasks.registering<FixJar> {
inputJar.set(remapJar.flatMap { it.outputJar })
vanillaJar.set(generalTasks.downloadServerJar.flatMap { it.outputJar })
}
val downloadMcLibraries by tasks.registering<DownloadMcLibraries> {

View file

@ -23,6 +23,7 @@
package io.papermc.paperweight.tasks
import io.papermc.paperweight.util.AsmUtil
import io.papermc.paperweight.util.SyntheticUtil
import io.papermc.paperweight.util.defaultOutput
import io.papermc.paperweight.util.file
import java.util.jar.JarFile
@ -33,7 +34,6 @@ import org.gradle.api.tasks.InputFile
import org.gradle.api.tasks.OutputFile
import org.gradle.api.tasks.TaskAction
import org.objectweb.asm.ClassReader
import org.objectweb.asm.ClassVisitor
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.Opcodes
import org.objectweb.asm.Type
@ -46,6 +46,9 @@ abstract class FixJar : BaseTask(), AsmUtil {
@get:InputFile
abstract val inputJar: RegularFileProperty
@get:InputFile
abstract val vanillaJar: RegularFileProperty
@get:OutputFile
abstract val outputJar: RegularFileProperty
@ -55,55 +58,55 @@ abstract class FixJar : BaseTask(), AsmUtil {
@TaskAction
fun run() {
JarOutputStream(outputJar.file.outputStream()).use { out ->
JarFile(inputJar.file).use { jarFile ->
for (entry in jarFile.entries()) {
if (!entry.name.endsWith(".class")) {
out.putNextEntry(entry)
JarFile(vanillaJar.file).use { vanillaJar ->
JarOutputStream(outputJar.file.outputStream()).use { out ->
JarFile(inputJar.file).use { jarFile ->
val classNodeCache = ClassNodeCache(jarFile, vanillaJar)
for (entry in jarFile.entries()) {
if (!entry.name.endsWith(".class")) {
out.putNextEntry(entry)
try {
jarFile.getInputStream(entry).copyTo(out)
} finally {
out.closeEntry()
}
continue
}
try {
jarFile.getInputStream(entry).copyTo(out)
val node =
classNodeCache.findClass(entry.name) ?: error("No ClassNode found for known entry")
ParameterAnnotationFixer(node).visitNode()
OverrideAnnotationAdder(node, classNodeCache).visitNode()
val writer = ClassWriter(0)
node.accept(writer)
out.putNextEntry(ZipEntry(entry.name))
out.write(writer.toByteArray())
out.flush()
} finally {
out.closeEntry()
}
continue
}
val classData = jarFile.getInputStream(entry).readBytes()
try {
val node = ClassNode(Opcodes.ASM9)
var visitor: ClassVisitor = node
visitor = ParameterAnnotationFixer(node, visitor)
val reader = ClassReader(classData)
reader.accept(visitor, 0)
val writer = ClassWriter(0)
node.accept(writer)
out.putNextEntry(ZipEntry(entry.name))
out.write(writer.toByteArray())
out.flush()
} finally {
out.closeEntry()
}
classNodeCache.clear()
}
}
}
}
}
}
/*
* This was adapted from code originally written by Pokechu22 in MCInjector
* Link: https://github.com/ModCoderPack/MCInjector/pull/3
*/
class ParameterAnnotationFixer(
private val node: ClassNode,
classVisitor: ClassVisitor?
) : ClassVisitor(Opcodes.ASM9, classVisitor), AsmUtil {
override fun visitEnd() {
super.visitEnd()
class ParameterAnnotationFixer(private val node: ClassNode) : AsmUtil {
fun visitNode() {
val expected = expectedSyntheticParams() ?: return
for (method in node.methods) {
@ -134,7 +137,8 @@ class ParameterAnnotationFixer(
}
method.visibleParameterAnnotations = process(params.size, synthParams.size, method.visibleParameterAnnotations)
method.invisibleParameterAnnotations = process(params.size, synthParams.size, method.invisibleParameterAnnotations)
method.invisibleParameterAnnotations =
process(params.size, synthParams.size, method.invisibleParameterAnnotations)
method.visibleParameterAnnotations?.let {
method.visibleAnnotableParameterCount = it.size
@ -170,3 +174,104 @@ class ParameterAnnotationFixer(
return true
}
}
class OverrideAnnotationAdder(private val node: ClassNode, private val classNodeCache: ClassNodeCache) : AsmUtil {
fun visitNode() {
val superMethods = collectSuperMethods(node)
val disqualifiedMethods = Opcodes.ACC_STATIC or Opcodes.ACC_PRIVATE
for (method in node.methods) {
if (method.access in disqualifiedMethods) {
continue
}
if (method.name == "<init>" || method.name == "<clinit>") {
continue
}
val (name, desc) = SyntheticUtil.findBaseMethod(method, node.name)
if (method.name + method.desc in superMethods) {
val targetMethod = node.methods.firstOrNull { it.name == name && it.desc == desc } ?: method
if (targetMethod.invisibleAnnotations == null) {
targetMethod.invisibleAnnotations = arrayListOf()
}
val annoClass = "Ljava/lang/Override;"
if (targetMethod.invisibleAnnotations.none { it.desc == annoClass }) {
targetMethod.invisibleAnnotations.add(AnnotationNode(annoClass))
}
}
}
}
private fun collectSuperMethods(node: ClassNode): Set<String> {
fun collectSuperMethods(node: ClassNode, superMethods: HashSet<String>) {
val supers = listOfNotNull(node.superName, *node.interfaces.toTypedArray())
if (supers.isEmpty()) {
return
}
val disqualifiedMethods = Opcodes.ACC_STATIC or Opcodes.ACC_PRIVATE
val superNodes = supers.mapNotNull { classNodeCache.findClass(it) }
superNodes.asSequence()
.flatMap { classNode -> classNode.methods.asSequence() }
.filter { method -> method.access !in disqualifiedMethods }
.filter { method -> method.name != "<init>" && method.name != "<clinit>" }
.map { method -> method.name + method.desc }
.toCollection(superMethods)
for (superNode in superNodes) {
collectSuperMethods(superNode, superMethods)
}
}
val result = hashSetOf<String>()
collectSuperMethods(node, result)
return result
}
}
class ClassNodeCache(private val jarFile: JarFile, private val fallbackJar: JarFile) {
private val classNodeMap = hashMapOf<String, ClassNode?>()
fun findClass(name: String): ClassNode? {
return classNodeMap.computeIfAbsent(normalize(name)) { fileName ->
val classData = findClassData(fileName) ?: return@computeIfAbsent null
val classReader = ClassReader(classData)
val node = ClassNode(Opcodes.ASM9)
classReader.accept(node, 0)
return@computeIfAbsent node
}
}
private fun findClassData(className: String): ByteArray? {
val entry = ZipEntry(className)
return (jarFile.getInputStream(entry) // remapped class
?: fallbackJar.getInputStream(entry) // library class
?: ClassLoader.getSystemResourceAsStream(className))?.use { it.readBytes() } // JDK class
}
private fun normalize(name: String): String {
var workingName = name
if (workingName.endsWith(".class")) {
workingName = workingName.substring(0, workingName.length - 6)
}
var startIndex = 0
var endIndex = workingName.length
if (workingName.startsWith('L')) {
startIndex = 1
}
if (workingName.endsWith(';')) {
endIndex--
}
return workingName.substring(startIndex, endIndex).replace('.', '/') + ".class"
}
fun clear() {
classNodeMap.clear()
}
}

View file

@ -23,6 +23,7 @@
package io.papermc.paperweight.tasks
import io.papermc.paperweight.util.AsmUtil
import io.papermc.paperweight.util.SyntheticUtil
import io.papermc.paperweight.util.defaultOutput
import io.papermc.paperweight.util.file
import org.gradle.api.file.RegularFileProperty
@ -288,76 +289,13 @@ object SyntheticMethods {
private val methods: MutableList<Data>
) : MethodNode(Opcodes.ASM9, access, name, descriptor, signature, exceptions) {
private enum class State {
IN_PARAMS,
INVOKE,
RETURN,
OTHER_INSN
}
// This tries to match the behavior of SpecialSource2's SyntheticFinder.addSynthetics() method
override fun visitEnd() {
var state = State.IN_PARAMS
var nextLvt = 0
val (baseName, baseDesc) = SyntheticUtil.findBaseMethod(this, className)
var invokeInsn: MethodInsnNode? = null
loop@for (insn in instructions) {
if (insn is LabelNode || insn is LineNumberNode || insn is TypeInsnNode) {
continue
}
if (state == State.IN_PARAMS) {
if (insn !is VarInsnNode || insn.`var` != nextLvt) {
state = State.INVOKE
}
}
when (state) {
State.IN_PARAMS -> {
nextLvt++
if (insn.opcode == Opcodes.LLOAD || insn.opcode == Opcodes.DLOAD) {
nextLvt++
}
}
State.INVOKE -> {
// Must be a virtual or interface invoke instruction
if ((insn.opcode != Opcodes.INVOKEVIRTUAL && insn.opcode != Opcodes.INVOKEINTERFACE) || insn !is MethodInsnNode) {
return
}
invokeInsn = insn
state = State.RETURN
}
State.RETURN -> {
// The next instruction must be a return
if (insn.opcode !in Opcodes.IRETURN..Opcodes.RETURN) {
return
}
state = State.OTHER_INSN
}
State.OTHER_INSN -> {
// We shouldn't see any other instructions
return
}
}
if (baseName != name || baseDesc != desc) {
// Add this method as a synthetic for baseName
methods += Data(className, baseDesc, name, baseName)
}
val invoke = invokeInsn ?: return
// Must be a method in the same class with a different signature
if (className != invoke.owner || name == invoke.name || desc == invoke.desc) {
return
}
// The descriptors need to be the same size
if (Type.getArgumentTypes(desc).size != Type.getArgumentTypes(invoke.desc).size) {
return
}
// Add this method as a synthetic accessor for insn.name
methods += Data(className, invoke.desc, name, invoke.name)
}
}

View file

@ -24,6 +24,6 @@ package io.papermc.paperweight.util
interface AsmUtil {
operator fun Int.contains(value: Int): Boolean {
return value and this == value
return value and this != 0
}
}

View file

@ -0,0 +1,95 @@
package io.papermc.paperweight.util
import org.objectweb.asm.Opcodes
import org.objectweb.asm.Type
import org.objectweb.asm.tree.LabelNode
import org.objectweb.asm.tree.LineNumberNode
import org.objectweb.asm.tree.MethodInsnNode
import org.objectweb.asm.tree.MethodNode
import org.objectweb.asm.tree.TypeInsnNode
import org.objectweb.asm.tree.VarInsnNode
object SyntheticUtil : AsmUtil {
fun findBaseMethod(node: MethodNode, className: String): MethodDesc {
if (node.access !in Opcodes.ACC_SYNTHETIC) {
return MethodDesc(node.name, node.desc)
}
return checkMethodNode(node, className) ?: MethodDesc(node.name, node.desc)
}
private enum class State {
IN_PARAMS,
INVOKE,
RETURN,
OTHER_INSN
}
// This tries to match the behavior of SpecialSource2's SyntheticFinder.addSynthetics() method
private fun checkMethodNode(node: MethodNode, className: String): MethodDesc? {
var state = State.IN_PARAMS
var nextLvt = 0
var invokeInsn: MethodInsnNode? = null
loop@for (insn in node.instructions) {
if (insn is LabelNode || insn is LineNumberNode || insn is TypeInsnNode) {
continue
}
if (state == State.IN_PARAMS) {
if (insn !is VarInsnNode || insn.`var` != nextLvt) {
state = State.INVOKE
}
}
when (state) {
State.IN_PARAMS -> {
nextLvt++
if (insn.opcode == Opcodes.LLOAD || insn.opcode == Opcodes.DLOAD) {
nextLvt++
}
}
State.INVOKE -> {
// Must be a virtual or interface invoke instruction
if ((insn.opcode != Opcodes.INVOKEVIRTUAL && insn.opcode != Opcodes.INVOKEINTERFACE) || insn !is MethodInsnNode) {
return null
}
invokeInsn = insn
state = State.RETURN
}
State.RETURN -> {
// The next instruction must be a return
if (insn.opcode !in Opcodes.IRETURN..Opcodes.RETURN) {
return null
}
state = State.OTHER_INSN
}
State.OTHER_INSN -> {
// We shouldn't see any other instructions
return null
}
}
}
val invoke = invokeInsn ?: return null
// Must be a method in the same class with a different signature
if (className != invoke.owner || (node.name == invoke.name && node.desc == invoke.desc)) {
return null
}
// The descriptors need to be the same size
if (Type.getArgumentTypes(node.desc).size != Type.getArgumentTypes(invoke.desc).size) {
return null
}
// Add this method as a synthetic accessor for insn.name
return MethodDesc(invoke.name, invoke.desc)
}
}
data class MethodDesc(val name: String, val desc: String)