Skip to content
This repository was archived by the owner on Mar 24, 2025. It is now read-only.

rename #258

Merged
merged 3 commits into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

package org.bitlap.tools.internal

import java.time.ZonedDateTime
import java.time.format.DateTimeFormatter
import scala.annotation.tailrec
import scala.reflect.macros.whitebox

Expand Down Expand Up @@ -56,7 +54,6 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
def impl(annottees: Expr[Any]*): Expr[Any] = {
checkAnnottees(annottees)
val resTree = collectCustomExpr(annottees)(createCustomExpr)
printTree(force = false, resTree.tree)
resTree
}

Expand All @@ -74,26 +71,13 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
*/
def evalTree[T: WeakTypeTag](tree: Tree): T = c.eval(c.Expr[T](c.untypecheck(tree.duplicate)))

/** Output ast result.
*
* @param force
* @param resTree
*/
def printTree(force: Boolean, resTree: Tree): Unit =
c.info(
c.enclosingPosition,
s"\n###### Time: ${ZonedDateTime.now().format(DateTimeFormatter.ISO_ZONED_DATE_TIME)} " +
s"Expanded macro start ######\n" + resTree.toString() + "\n###### Expanded macro end ######\n",
force = false
)

/** Check the class and its companion object, and return the class definition.
*
* @param annottees
* @return
* Return a [[scala.reflect.api.Trees#ClassDef]]
*/
def checkGetClassDef(annottees: Seq[Expr[Any]]): ClassDef =
def checkClassDef(annottees: Seq[Expr[Any]]): ClassDef =
annottees.map(_.tree).toList match {
case (classDecl: ClassDef) :: Nil => classDecl
case (classDecl: ClassDef) :: (_: ModuleDef) :: Nil => classDecl
Expand All @@ -106,7 +90,7 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @return
* Return a optional [[scala.reflect.api.Trees#ModuleDef]]
*/
def getModuleDefOption(annottees: Seq[Expr[Any]]): Option[ModuleDef] =
def moduleDef(annottees: Seq[Expr[Any]]): Option[ModuleDef] =
annottees.map(_.tree).toList match {
case (moduleDef: ModuleDef) :: Nil => Some(moduleDef)
case (_: ClassDef) :: Nil => None
Expand All @@ -126,8 +110,8 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
def collectCustomExpr(
annottees: Seq[Expr[Any]]
)(modifyAction: (ClassDef, Option[ModuleDef]) => Any): Expr[Nothing] = {
val classDef = checkGetClassDef(annottees)
val compDecl = getModuleDefOption(annottees)
val classDef = checkClassDef(annottees)
val compDecl = moduleDef(annottees)
modifyAction(classDef, compDecl).asInstanceOf[Expr[Nothing]]
}

Expand All @@ -149,7 +133,7 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @return
* Return false if mods exists `private[this]` or `protected[this]`
*/
def isNotLocalClassMember(tree: Tree): Boolean = {
def nonLocalMember(tree: Tree): Boolean = {
lazy val modifierNotLocal = (mods: Modifiers) =>
!(
mods.hasFlag(Flag.PRIVATE | Flag.LOCAL) | mods.hasFlag(Flag.PROTECTED | Flag.LOCAL)
Expand All @@ -167,7 +151,7 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @return
* Return a sequence of [[scala.reflect.api.Trees#Tree]], each one is `tname: tpt`
*/
def getConstructorParamsNameWithType(annotteeClassParams: Seq[Tree]): Seq[Tree] =
def classParamsTermNameWithType(annotteeClassParams: Seq[Tree]): Seq[Tree] =
annotteeClassParams.map(_.asInstanceOf[ValDef]).map(v => q"${v.name}: ${v.tpt}")

/** Modify companion object or object.
Expand All @@ -193,7 +177,7 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @return
* Return a sequence of [[scala.reflect.api.Trees#ValDef]]
*/
def getClassMemberValDefs(annotteeClassDefinitions: Seq[Tree]): Seq[ValDef] =
def classValDefs(annotteeClassDefinitions: Seq[Tree]): Seq[ValDef] =
annotteeClassDefinitions
.filter(_ match {
case _: ValDef => true
Expand All @@ -207,7 +191,7 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @return
* Return a sequence of [[scala.reflect.api.Trees#ValDef]]
*/
def getClassConstructorValDefsFlatten(annotteeClassParams: List[List[Tree]]): Seq[ValDef] =
def classConstructorValDefs(annotteeClassParams: List[List[Tree]]): Seq[ValDef] =
annotteeClassParams.flatten.map(_.asInstanceOf[ValDef])

/** Extract the constructor params [[scala.reflect.api.Trees#ValDef]] not flatten.
Expand All @@ -216,7 +200,7 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @return
* Return a double sequence of [[scala.reflect.api.Trees#ValDef]]
*/
def getClassConstructorValDefsNotFlatten(annotteeClassParams: List[List[Tree]]): Seq[Seq[ValDef]] =
def classConstructorValDefss(annotteeClassParams: List[List[Tree]]): Seq[Seq[ValDef]] =
annotteeClassParams.map(_.map(_.asInstanceOf[ValDef]))

/** Extract the methods belonging to the class, contains Secondary Constructor.
Expand All @@ -225,7 +209,7 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @return
* Return a sequence of [[scala.reflect.api.Trees#DefDef]]
*/
def getClassMemberDefDefs(annotteeClassDefinitions: Seq[Tree]): Seq[DefDef] =
def classMemberDefDefs(annotteeClassDefinitions: Seq[Tree]): Seq[DefDef] =
annotteeClassDefinitions
.filter(_ match {
case _: DefDef => true
Expand All @@ -243,8 +227,8 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @example
* Return a tree, such as `new TestClass12(i)(j)(k)(t)`
*/
def getConstructorWithCurrying(typeName: TypeName, fieldss: List[List[Tree]], isCase: Boolean): Tree = {
val fieldssValDefNotFlatten = getClassConstructorValDefsNotFlatten(fieldss)
def curriedConstructor(typeName: TypeName, fieldss: List[List[Tree]], isCase: Boolean): Tree = {
val fieldssValDefNotFlatten = classConstructorValDefss(fieldss)
val allFieldsTermName = fieldssValDefNotFlatten.map(_.map(_.name.toTermName))
// not currying
val constructor = if (fieldss.isEmpty || fieldss.size == 1) {
Expand All @@ -266,7 +250,7 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @return
* Scala type name
*/
def toScalaType(javaType: String): String = {
def java2ScalaType(javaType: String): String = {
val types = Map(
"java.lang.Integer" -> "Int",
"java.lang.Long" -> "Long",
Expand All @@ -288,7 +272,7 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @return
* Return a sequence of [[scala.reflect.api.Names#TypeName]]
*/
def extractClassTypeParamsTypeName(tpParams: List[Tree]): List[TypeName] =
def typeParams(tpParams: List[Tree]): List[TypeName] =
tpParams.map(_.asInstanceOf[TypeDef].name)

/** Is there a parent class? Does not contains sdk class, such as AnyRef and Object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ object applyMacro {
fieldss: List[List[Tree]],
classTypeParams: List[Tree]
): Tree = {
val allFieldsTermName = fieldss.map(f => getConstructorParamsNameWithType(f))
val returnTypeParams = extractClassTypeParamsTypeName(classTypeParams)
val allFieldsTermName = fieldss.map(f => classParamsTermNameWithType(f))
val returnTypeParams = typeParams(classTypeParams)
// not currying
val applyMethod = if (fieldss.isEmpty || fieldss.size == 1) {
q"def apply[..$classTypeParams](..${allFieldsTermName.flatten}): $typeName[..$returnTypeParams] = ${getConstructorWithCurrying(typeName, fieldss, isCase = false)}"
q"def apply[..$classTypeParams](..${allFieldsTermName.flatten}): $typeName[..$returnTypeParams] = ${curriedConstructor(typeName, fieldss, isCase = false)}"
} else {
// currying
val first = allFieldsTermName.head
q"def apply[..$classTypeParams](..$first)(...${allFieldsTermName.tail}): $typeName[..$returnTypeParams] = ${getConstructorWithCurrying(typeName, fieldss, isCase = false)}"
q"def apply[..$classTypeParams](..$first)(...${allFieldsTermName.tail}): $typeName[..$returnTypeParams] = ${curriedConstructor(typeName, fieldss, isCase = false)}"
}
applyMethod
}
Expand All @@ -78,7 +78,7 @@ object applyMacro {

override def checkAnnottees(annottees: Seq[c.universe.Expr[Any]]): Unit = {
super.checkAnnottees(annottees)
val annotateeClass: ClassDef = checkGetClassDef(annottees)
val annotateeClass: ClassDef = checkClassDef(annottees)
if (isCaseClass(annotateeClass)) {
c.abort(c.enclosingPosition, ErrorMessage.ONLY_CASE_CLASS)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ object builderMacro {

private def getFieldSetMethod(typeName: TypeName, field: Tree, classTypeParams: List[Tree]): Tree = {
val builderClassName = getBuilderClassName(typeName)
val returnTypeParams = extractClassTypeParamsTypeName(classTypeParams)
val returnTypeParams = typeParams(classTypeParams)
lazy val valDefMapTo = (v: ValDef) => q"""
def ${v.name}(${v.name}: ${v.tpt}): $builderClassName[..$returnTypeParams] = {
this.${v.name} = ${v.name}
Expand All @@ -64,7 +64,7 @@ object builderMacro {
val builderClassName = getBuilderClassName(typeName)
val builderFieldMethods = fields.map(f => getFieldSetMethod(typeName, f, classTypeParams))
val builderFieldDefinitions = fields.map(f => getFieldDefinition(f))
val returnTypeParams = extractClassTypeParamsTypeName(classTypeParams)
val returnTypeParams = typeParams(classTypeParams)
val builderMethod =
q"def builder[..$classTypeParams](): $builderClassName[..$returnTypeParams] = new $builderClassName()"
val buulderClass =
Expand All @@ -75,7 +75,7 @@ object builderMacro {

..$builderFieldMethods

def build(): $typeName[..$returnTypeParams] = ${getConstructorWithCurrying(typeName, fieldss, isCase)}
def build(): $typeName[..$returnTypeParams] = ${curriedConstructor(typeName, fieldss, isCase)}
}
"""
List(builderMethod, buulderClass)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ object constructorMacro {
}

private def getMutableValDefAndExcludeFields(annotteeClassDefinitions: Seq[Tree]): Seq[c.universe.ValDef] =
getClassMemberValDefs(annotteeClassDefinitions).filter(v =>
classValDefs(annotteeClassDefinitions).filter(v =>
v.mods.hasFlag(Flag.MUTABLE) &&
!extractOptions.contains(v.name.decodedName.toString)
)
Expand All @@ -54,7 +54,7 @@ object constructorMacro {
getMutableValDefAndExcludeFields(annotteeClassDefinitions).map { v =>
if (v.tpt.isEmpty) { // val i = 1, tpt is `<type ?>`
// TODO getClass RETURN a java type, maybe we can try use class reflect to get the fields type name.
q"${v.name}: ${TypeName(toScalaType(evalTree(v.rhs).getClass.getTypeName))}"
q"${v.name}: ${TypeName(java2ScalaType(evalTree(v.rhs).getClass.getTypeName))}"
} else {
q"${v.name}: ${v.tpt}"
}
Expand All @@ -76,9 +76,9 @@ object constructorMacro {
}
// Extract the internal fields of members belonging to the class, but not in primary constructor.
val annotteeClassFieldNames = getMutableValDefAndExcludeFields(annotteeClassDefinitions).map(_.name)
val allFieldsTermName = getClassConstructorValDefsNotFlatten(annotteeClassParams).map(_.map(_.name.toTermName))
val allFieldsTermName = classConstructorValDefss(annotteeClassParams).map(_.map(_.name.toTermName))
// Extract the field of the primary constructor.
val classParamsNameWithType = getConstructorParamsNameWithType(annotteeClassParams.flatten)
val classParamsNameWithType = classParamsTermNameWithType(annotteeClassParams.flatten)
val applyMethod = if (annotteeClassParams.isEmpty || annotteeClassParams.size == 1) {
q"""
def this(..${classParamsNameWithType ++ classInternalFieldsWithType}) = {
Expand All @@ -88,7 +88,7 @@ object constructorMacro {
"""
} else {
// NOTE: currying constructor overload must be placed in the first bracket block.
val allClassCtorParamsNameWithType = annotteeClassParams.map(cc => getConstructorParamsNameWithType(cc))
val allClassCtorParamsNameWithType = annotteeClassParams.map(cc => classParamsTermNameWithType(cc))
q"""
def this(..${allClassCtorParamsNameWithType.head ++ classInternalFieldsWithType})(...${allClassCtorParamsNameWithType.tail}) = {
this(..${allFieldsTermName.head})(...${allFieldsTermName.tail})
Expand All @@ -110,7 +110,7 @@ object constructorMacro {

override def checkAnnottees(annottees: Seq[c.universe.Expr[Any]]): Unit = {
super.checkAnnottees(annottees)
val annotateeClass: ClassDef = checkGetClassDef(annottees)
val annotateeClass: ClassDef = checkClassDef(annottees)
if (isCaseClass(annotateeClass)) {
c.abort(c.enclosingPosition, ErrorMessage.ONLY_CLASS)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ object elapsedMacro {
getNewMethod(defDef)
}
}
printTree(force = false, resTree)
c.Expr[Any](resTree)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ object equalsAndHashCodeMacro {

override def checkAnnottees(annottees: Seq[c.universe.Expr[Any]]): Unit = {
super.checkAnnottees(annottees)
val annotateeClass: ClassDef = checkGetClassDef(annottees)
val annotateeClass: ClassDef = checkClassDef(annottees)
if (isCaseClass(annotateeClass)) {
c.abort(c.enclosingPosition, ErrorMessage.ONLY_CLASS)
}
Expand All @@ -52,12 +52,12 @@ object equalsAndHashCodeMacro {
/** Extract the internal fields of members belonging to the class.
*/
private def getInternalFieldsTermNameExcludeLocal(annotteeClassDefinitions: Seq[Tree]): Seq[TermName] = {
if (annotteeClassDefinitions.exists(f => isNotLocalClassMember(f))) {
if (annotteeClassDefinitions.exists(f => nonLocalMember(f))) {
c.info(c.enclosingPosition, s"There is a non private class definition inside the class", force = true)
}
getClassMemberValDefs(annotteeClassDefinitions)
classValDefs(annotteeClassDefinitions)
.filter(p =>
isNotLocalClassMember(p) &&
nonLocalMember(p) &&
!extractOptions.contains(p.name.decodedName.toString)
)
.map(_.name.toTermName)
Expand All @@ -70,7 +70,7 @@ object equalsAndHashCodeMacro {
superClasses: Seq[Tree],
annotteeClassDefinitions: Seq[Tree]
): List[Tree] = {
val existsCanEqual = getClassMemberDefDefs(annotteeClassDefinitions).exists {
val existsCanEqual = classMemberDefDefs(annotteeClassDefinitions).exists {
case defDef: DefDef if defDef.name.decodedName.toString == "canEqual" && defDef.vparamss.nonEmpty =>
val safeValDefs = valDefAccessors(defDef.vparamss.flatten)
safeValDefs.exists(_.paramType.toString == "Any") && safeValDefs.exists(_.name.decodedName.toString == "that")
Expand Down Expand Up @@ -113,8 +113,8 @@ object equalsAndHashCodeMacro {

override def createCustomExpr(classDecl: ClassDef, compDeclOpt: Option[ModuleDef]): Any = {
lazy val map = (classDefinition: ClassDefinition) =>
getClassConstructorValDefsFlatten(classDefinition.classParamss)
.filter(cf => isNotLocalClassMember(cf))
classConstructorValDefs(classDefinition.classParamss)
.filter(cf => nonLocalMember(cf))
.map(_.name.toTermName) ++
getInternalFieldsTermNameExcludeLocal(classDefinition.body)
val classDefinition = mapToClassDeclInfo(classDecl)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ object javaCompatibleMacro {

override def checkAnnottees(annottees: Seq[c.universe.Expr[Any]]): Unit = {
super.checkAnnottees(annottees)
val annotateeClass: ClassDef = checkGetClassDef(annottees)
val annotateeClass: ClassDef = checkClassDef(annottees)
if (!isCaseClass(annotateeClass)) {
c.abort(c.enclosingPosition, ErrorMessage.ONLY_CASE_CLASS)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,27 +68,26 @@ object logMacro {
}
val newClass = extractOptions match {
case ScalaLoggingLazy | ScalaLoggingStrict =>
appendImplDefSuper(checkGetClassDef(annottees), _ => List(logTree(annottees)))
appendImplDefSuper(checkClassDef(annottees), _ => List(logTree(annottees)))
case _ =>
prependImplDefBody(checkGetClassDef(annottees), _ => List(logTree(annottees)))
prependImplDefBody(checkClassDef(annottees), _ => List(logTree(annottees)))
}
val moduleDef = getModuleDefOption(annottees)
val md = moduleDef(annottees)
q"""
${if (moduleDef.isEmpty) EmptyTree else moduleDef.get}
${if (md.isEmpty) EmptyTree else md.get}
$newClass
"""
case (_: ModuleDef) :: _ =>
extractOptions match {
case ScalaLoggingLazy | ScalaLoggingStrict =>
appendImplDefSuper(getModuleDefOption(annottees).get, _ => List(logTree(annottees)))
case _ => prependImplDefBody(getModuleDefOption(annottees).get, _ => List(logTree(annottees)))
appendImplDefSuper(moduleDef(annottees).get, _ => List(logTree(annottees)))
case _ => prependImplDefBody(moduleDef(annottees).get, _ => List(logTree(annottees)))
}
// Note: If a class is annotated and it has a companion, then both are passed into the macro.
// (But not vice versa - if an object is annotated and it has a companion class, only the object itself is expanded).
// see https://docs.scala-lang.org/overviews/macros/annotations.html
}

printTree(force = false, resTree)
c.Expr[Any](resTree)
}
}
Expand Down
Loading