Add glnn #205

Merged
merged 8 commits on Jun 4, 2024
4 changes: 3 additions & 1 deletion README.md
@@ -436,7 +436,9 @@ Now, GammaGL supports about 60 models, we welcome everyone to use or contribute
| [GGD [NeurIPS 2022]](./examples/ggd) | | :heavy_check_mark: | | :heavy_check_mark: |
| [LTD [WSDM 2022]](./examples/ltd) | | :heavy_check_mark: | | :heavy_check_mark: |
| [Graphormer [NeurIPS 2021]](./examples/graphormer) | | :heavy_check_mark: | | :heavy_check_mark: |
| [HiD-Net [AAAI 2023]](./examples/hid_net) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [FusedGAT [MLSys 2022]](./examples/fusedgat) | | :heavy_check_mark: | | |
| [GLNN [ICLR 2022]](./examples/glnn) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |


| Contrastive Learning | TensorFlow | PyTorch | Paddle | MindSpore |
3 changes: 2 additions & 1 deletion docs/source/api/gammagl.utils.rst
@@ -26,4 +26,5 @@ gammagl.utils
gammagl.utils.negative_sampling
gammagl.utils.to_scipy_sparse_matrix
gammagl.utils.read_embeddings
gammagl.utils.homophily
gammagl.utils.get_train_val_test_split
41 changes: 41 additions & 0 deletions examples/glnn/readme.md
@@ -0,0 +1,41 @@
# Graph-less Neural Networks (GLNN)

- Paper link: [https://arxiv.org/pdf/2110.08727](https://arxiv.org/pdf/2110.08727)
- Author's code repo: [https://github.com/snap-research/graphless-neural-networks](https://github.com/snap-research/graphless-neural-networks)
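
Concretely, the student MLP in `train_student.py` is trained with the weighted objective implemented in `cal_mlp_loss`, which combines the cross-entropy on hard labels with a KL-divergence distillation term on the teacher's soft predictions:

$$
\mathcal{L} = \lambda\,\mathrm{CE}\left(y,\ \hat{y}_{\mathrm{MLP}}\right) + (1-\lambda)\,\mathrm{KL}\left(\mathrm{softmax}(z_{\mathrm{teacher}})\,\|\,\mathrm{softmax}(z_{\mathrm{MLP}})\right)
$$

where $\lambda$ is set by the `--lamb` argument; with the default `--lamb 0`, training reduces to pure distillation from the teacher logits.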

# Dataset Statistics
| Dataset | # Nodes | # Edges | # Classes |
| -------- | ------- | ------- | --------- |
| Cora | 2,708 | 10,556 | 7 |
| Citeseer | 3,327 | 9,228 | 6 |
| Pubmed | 19,717 | 88,651 | 3 |
| Computers| 13,752 | 491,722 | 10 |
| Photo | 7,650 | 238,162 | 8 |

For dataset details, refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid) and [Amazon](https://gammagl.readthedocs.io/en/latest/generated/gammagl.datasets.Amazon.html#gammagl.datasets.Amazon).

# Results

- Available datasets: "cora", "citeseer", "pubmed", "computers", "photo"
- Available teachers: "SAGE", "GCN", "GAT", "APPNP", "MLP"

```bash
TL_BACKEND="tensorflow" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="tensorflow" python train_student.py --dataset cora --teacher SAGE
TL_BACKEND="torch" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="torch" python train_student.py --dataset cora --teacher SAGE
TL_BACKEND="paddle" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="paddle" python train_student.py --dataset cora --teacher SAGE
TL_BACKEND="mindspore" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="mindspore" python train_student.py --dataset cora --teacher SAGE
```
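
Note: `train_teacher.py` should be run first with the same `--dataset`/`--teacher` pair, since `train_student.py` loads the saved teacher logits from the working directory (`{dataset}_{teacher}_logits.npy`, e.g. `cora_SAGE_logits.npy` for the commands above).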

| Dataset | Paper | Our(tf) | Our(th) | Our(pd) | Our(ms) |
| --------- | ---------- | ---------- | ---------- | ---------- | ---------- |
| Cora | 80.54±1.35 | 80.94±0.31 | 80.84±0.30 | 80.90±0.21 | 81.04±0.30 |
| Citeseer | 71.77±2.01 | 70.74±0.87 | 71.34±0.55 | 71.18±1.20 | 70.58±1.14 |
| Pubmed | 75.42±2.31 | 77.90±0.07 | 77.88±0.23 | 77.78±0.19 | 77.78±0.13 |
| Computers | 83.03±1.87 | 83.45±0.61 | 82.78±0.47 | 83.03±0.14 | 83.40±0.45 |
| Photo | 92.11±1.08 | 91.93±0.16 | 91.91±0.24 | 91.89±0.27 | 91.88±0.21 |

- The reported accuracy is averaged over 5 runs.
140 changes: 140 additions & 0 deletions examples/glnn/train.conf.yaml
@@ -0,0 +1,140 @@
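# Per-dataset teacher hyper-parameters; settings under "global" apply to every run
# unless a dataset/model section below overrides them.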
global:
num_layers: 2
hidden_dim: 128
learning_rate: 0.01

cora:
SAGE:
fan_out: 5,5
learning_rate: 0.01
dropout_ratio: 0
weight_decay: 0.0005

GCN:
hidden_dim: 64
dropout_ratio: 0.8
weight_decay: 0.001

MLP:
learning_rate: 0.01
weight_decay: 0.005
dropout_ratio: 0.6

GAT:
dropout_ratio: 0.6
weight_decay: 0.01
num_heads: 8
attn_dropout_ratio: 0.3

APPNP:
dropout_ratio: 0.5
weight_decay: 0.01


citeseer:
SAGE:
fan_out: 5,5
learning_rate: 0.01
dropout_ratio: 0
weight_decay: 0.0005

GCN:
hidden_dim: 64
dropout_ratio: 0.8
weight_decay: 0.001

MLP:
learning_rate: 0.01
weight_decay: 0.001
dropout_ratio: 0.1

GAT:
dropout_ratio: 0.6
weight_decay: 0.01
num_heads: 8
attn_dropout_ratio: 0.3

APPNP:
dropout_ratio: 0.5
weight_decay: 0.01

pubmed:
SAGE:
fan_out: 5,5
learning_rate: 0.01
dropout_ratio: 0
weight_decay: 0.0005

GCN:
hidden_dim: 64
dropout_ratio: 0.8
weight_decay: 0.001

MLP:
learning_rate: 0.005
weight_decay: 0
dropout_ratio: 0.4

GAT:
dropout_ratio: 0.6
weight_decay: 0.01
num_heads: 8
attn_dropout_ratio: 0.3

APPNP:
dropout_ratio: 0.5
weight_decay: 0.01

computers:
SAGE:
fan_out: 5,5
learning_rate: 0.01
dropout_ratio: 0
weight_decay: 0.0005

GCN:
hidden_dim: 64
dropout_ratio: 0.8
weight_decay: 0.001

MLP:
learning_rate: 0.001
weight_decay: 0.002
dropout_ratio: 0.3

GAT:
dropout_ratio: 0.6
weight_decay: 0.01
num_heads: 8
attn_dropout_ratio: 0.3

APPNP:
dropout_ratio: 0.5
weight_decay: 0.01

photo:
SAGE:
fan_out: 5,5
learning_rate: 0.01
dropout_ratio: 0
weight_decay: 0.0005

GCN:
hidden_dim: 64
dropout_ratio: 0.8
weight_decay: 0.001

MLP:
learning_rate: 0.005
weight_decay: 0.002
dropout_ratio: 0.3

GAT:
dropout_ratio: 0.6
weight_decay: 0.01
num_heads: 8
attn_dropout_ratio: 0.3

APPNP:
dropout_ratio: 0.5
weight_decay: 0.01
168 changes: 168 additions & 0 deletions examples/glnn/train_student.py
@@ -0,0 +1,168 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR

import yaml
import argparse
import tensorlayerx as tlx
from gammagl.datasets import Planetoid, Amazon
from gammagl.models import MLP
from gammagl.utils import mask_to_index
from tensorlayerx.model import TrainOneStep, WithLoss


class SemiSpvzLoss(WithLoss):
def __init__(self, net, loss_fn):
super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)

def forward(self, data, teacher_logits):
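        # distillation loss over the nodes in t_idx; args.lamb (module-level CLI args) balances hard labels vs. teacher logits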
student_logits = self.backbone_network(data['x'])
train_y = tlx.gather(data['y'], data['t_idx'])
train_teacher_logits = tlx.gather(teacher_logits, data['t_idx'])
train_student_logits = tlx.gather(student_logits, data['t_idx'])
loss = self._loss_fn(train_y, train_student_logits, train_teacher_logits, args.lamb)
return loss


def get_training_config(config_path, model_name, dataset):
with open(config_path, "r") as conf:
full_config = yaml.load(conf, Loader=yaml.FullLoader)
dataset_specific_config = full_config["global"]
model_specific_config = full_config[dataset][model_name]

if model_specific_config is not None:
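        # model-specific settings override the shared "global" defaults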
specific_config = dict(dataset_specific_config, **model_specific_config)
else:
specific_config = dataset_specific_config

specific_config["model_name"] = model_name
return specific_config


def calculate_acc(logits, y, metrics):
metrics.update(logits, y)
rst = metrics.result()
metrics.reset()
return rst


def kl_divergence(teacher_logits, student_logits):
# convert logits to probabilities
teacher_probs = tlx.softmax(teacher_logits)
student_probs = tlx.softmax(student_logits)
# compute KL divergence
kl_div = tlx.reduce_sum(teacher_probs * (tlx.log(teacher_probs+1e-10) - tlx.log(student_probs+1e-10)), axis=-1)
return tlx.reduce_mean(kl_div)


def cal_mlp_loss(labels, student_logits, teacher_logits, lamb):
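    # lamb weights the supervised cross-entropy on hard labels; (1 - lamb) weights the KL distillation from the teacher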
loss_l = tlx.losses.softmax_cross_entropy_with_logits(student_logits, labels)
loss_t = kl_divergence(teacher_logits, student_logits)
return lamb * loss_l + (1 - lamb) * loss_t


def train_student(args):
# load datasets
    dataset_name = args.dataset.lower()
    if dataset_name not in ['cora', 'pubmed', 'citeseer', 'computers', 'photo']:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    if dataset_name in ['cora', 'pubmed', 'citeseer']:
        dataset = Planetoid(args.dataset_path, dataset_name)
    elif dataset_name == 'computers':
        # 200 training nodes and 300 validation nodes out of 13,752
        dataset = Amazon(args.dataset_path, dataset_name, train_ratio=200/13752, val_ratio=(200/13752)*1.5)
    elif dataset_name == 'photo':
        # 160 training nodes and 240 validation nodes out of 7,650
        dataset = Amazon(args.dataset_path, dataset_name, train_ratio=160/7650, val_ratio=(160/7650)*1.5)
graph = dataset[0]

# load teacher_logits from .npy file
teacher_logits = tlx.files.load_npy_to_any(path = r'./', name = f'{args.dataset}_{args.teacher}_logits.npy')
teacher_logits = tlx.ops.convert_to_tensor(teacher_logits)

    # for MindSpore, node indices (rather than boolean masks) must be passed
    train_idx = mask_to_index(graph.train_mask)
    test_idx = mask_to_index(graph.test_mask)
    val_idx = mask_to_index(graph.val_mask)
    # the student is distilled on the union of the train, test, and val nodes
    t_idx = tlx.concat([train_idx, test_idx, val_idx], axis=0)

    # the student is a plain MLP: it consumes only node features, never the graph structure
    net = MLP(in_channels=dataset.num_node_features,
hidden_channels=conf["hidden_dim"],
out_channels=dataset.num_classes,
num_layers=conf["num_layers"],
act=tlx.nn.ReLU(),
norm=None,
dropout=float(conf["dropout_ratio"]))

optimizer = tlx.optimizers.Adam(lr=conf["learning_rate"], weight_decay=conf["weight_decay"])
metrics = tlx.metrics.Accuracy()
train_weights = net.trainable_weights

loss_func = SemiSpvzLoss(net, cal_mlp_loss)
train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

data = {
"x": graph.x,
"y": graph.y,
"train_idx": train_idx,
"test_idx": test_idx,
"val_idx": val_idx,
"t_idx": t_idx
}

best_val_acc = 0
for epoch in range(args.n_epoch):
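        # one full-batch step: the MLP forward runs on all node features, the loss is evaluated on t_idx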
net.set_train()
train_loss = train_one_step(data, teacher_logits)
net.set_eval()
logits = net(data['x'])
val_logits = tlx.gather(logits, data['val_idx'])
val_y = tlx.gather(data['y'], data['val_idx'])
val_acc = calculate_acc(val_logits, val_y, metrics)

print("Epoch [{:0>3d}] ".format(epoch+1)\
+ " train loss: {:.4f}".format(train_loss.item())\
+ " val acc: {:.4f}".format(val_acc))

# save best model on evaluation set
if val_acc > best_val_acc:
best_val_acc = val_acc
net.save_weights(args.best_model_path+args.dataset+"_"+args.teacher+"_MLP.npz", format='npz_dict')

net.load_weights(args.best_model_path+args.dataset+"_"+args.teacher+"_MLP.npz", format='npz_dict')
net.set_eval()
logits = net(data['x'])
test_logits = tlx.gather(logits, data['test_idx'])
test_y = tlx.gather(data['y'], data['test_idx'])
test_acc = calculate_acc(test_logits, test_y, metrics)
print("Test acc: {:.4f}".format(test_acc))



if __name__ == '__main__':
# parameters setting
parser = argparse.ArgumentParser()
parser.add_argument("--model_config_path",type=str,default="./train.conf.yaml",help="path to modelconfigeration")
parser.add_argument("--teacher", type=str, default="SAGE", help="teacher model")
parser.add_argument("--lamb", type=float, default=0, help="parameter balances loss from hard labels and teacher outputs")
parser.add_argument("--n_epoch", type=int, default=200, help="number of epoch")
parser.add_argument('--dataset', type=str, default="cora", help="dataset")
parser.add_argument("--dataset_path", type=str, default=r'./data', help="path to save dataset")
parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model")
parser.add_argument("--gpu", type=int, default=0)

args = parser.parse_args()

conf = {}
if args.model_config_path is not None:
conf = get_training_config(args.model_config_path, args.teacher, args.dataset)
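        # YAML settings take precedence over argparse values for keys present in both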
conf = dict(args.__dict__, **conf)

if args.gpu >= 0:
tlx.set_device("GPU", args.gpu)
else:
tlx.set_device("CPU")

train_student(args)