add heat #206
Conversation
examples/heat/heat_trainer.py
Outdated
class MultiStepLR:
    def __init__(self, optimizer, milestones, gamma=0.1):
        self.optimizer = optimizer
        self.milestones = milestones
        self.gamma = gamma
        self.last_epoch = 0

    def step(self):
        if self.last_epoch in self.milestones:
            self.optimizer.lr = self.optimizer.lr * self.gamma

        self.last_epoch += 1
Reviewer: This part of the code just makes the optimizer's learning rate decay at specific epochs. tensorlayerx has a built-in equivalent, tlx.optimizers.lr.ExponentialDecay; see line 205 of examples/cogsl/cogsl_trainer.py for how to call it.
Author: ExponentialDecay doesn't seem to support specifying particular epochs, but MultiStepDecay does, so I switched to MultiStepDecay. Is that acceptable?
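For reference, a minimal sketch of the swap under discussion, assuming tlx.optimizers.lr.MultiStepDecay follows the same calling convention as the other tlx.optimizers.lr schedulers and that a scheduler can be passed to the optimizer as in the cogsl trainer (milestone and rate values are placeholders):

import tensorlayerx as tlx

# decay the learning rate by `gamma` at each milestone epoch
scheduler = tlx.optimizers.lr.MultiStepDecay(learning_rate=0.001,
                                             milestones=[30, 60],
                                             gamma=0.1)
optimizer = tlx.optimizers.Adam(lr=scheduler)

for epoch in range(100):  # placeholder epoch count
    # ... forward/backward pass for one epoch ...
    scheduler.step()  # advance the schedule once per epoch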
examples/heat/heat_trainer.py
Outdated
val_loss = []
train_loss = []
Reviewer: There's no need to store the val_loss and train_loss values here; just print the train_loss and val_loss computed in each epoch.
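A sketch of the suggested pattern; train_one_epoch and evaluate are hypothetical stand-ins for the PR's training and validation code:

for epoch in range(args.n_epoch):  # args.n_epoch is an assumed flag name
    train_loss_epo = train_one_epoch(net, trainDataloader)  # hypothetical helper
    val_loss_epo = evaluate(net, valDataloader)             # hypothetical helper
    print("Epoch [{:0>3d}] train loss: {:.4f} val loss: {:.4f}".format(
        epoch + 1, train_loss_epo, val_loss_epo))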
examples/heat/heat_trainer.py
Outdated
train_net = HEAT(args.hist_length, args.in_channels_node, args.out_channels, args.out_length,
                 args.in_channels_edge_attr, args.in_channels_edge_type, args.edge_attr_emb_size,
                 args.edge_type_emb_size, args.node_emb_size, args.heads, args.concat)
Reviewer: Calling the network structure train_net invites confusion; rename it to net.
# print(epoch)
train_loss_epo = 0.0
train_net.set_train()
for i, data in enumerate(trainDataloader):
Reviewer: If the variable i isn't used later, this can be written as for _, data in enumerate(trainDataloader); the same applies to the code at line 65.
Author: The variables i and j are both used later to compute the average loss.
examples/heat/heat_trainer.py
Outdated
train_loss_epo += tlx.convert_to_numpy(loss_each_data)

train_loss_epo = round(train_loss_epo * 0.3048 / (i + 1), 4)
Reviewer: What is the purpose of multiplying the loss by 0.3048 here?
Author: The dataset's units are feet; this converts the value to meters.
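Given the confusion, a named constant would make the conversion self-documenting; a sketch:

FEET_TO_METERS = 0.3048  # NGSIM coordinates are recorded in feet

train_loss_epo = round(train_loss_epo * FEET_TO_METERS / (i + 1), 4)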
examples/heat/heat_trainer.py
Outdated
scheduler.step()

# save model
train_net.save_weights(str(val_loss_epoch) + '.npz', format='npz_dict')
Reviewer: There's no need to save a model weights file every epoch; save only the weights of the best-performing model. See how other models' trainer files implement this: typically you save one best weights file based on validation performance, then load it for the test set, compute the test-set results, and print them to the console.
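A minimal sketch of that best-checkpoint pattern, reusing the save_weights/load_weights interface already in the PR (best_model_path is an assumed argument name):

best_val_loss = float("inf")
for epoch in range(args.n_epoch):
    # ... train, then compute val_loss_epoch on the validation set ...
    if val_loss_epoch < best_val_loss:
        best_val_loss = val_loss_epoch
        net.save_weights(args.best_model_path + "HEAT.npz", format='npz_dict')

# after training: restore the best weights, evaluate on the test set, print the result
net.load_weights(args.best_model_path + "HEAT.npz", format='npz_dict')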
gammagl/models/heat.py
Outdated
self.concat = concat
self.op = tlx.nn.Linear(in_features=4 * self.hist_length, out_features=self.in_channels_node, W_init='xavier_uniform')
self.op2 = tlx.nn.Linear(in_features=self.out_channels, out_features=self.out_length * 2)
self.leaky_relu = tlx.nn.LeakyReLU(0.1)
Reviewer: The LeakyReLU slope could also be defined as a hyperparameter at model initialization.
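A sketch of exposing the slope; leaky_rate mirrors the argparse flag that appears later in this review:

class HEAT(tlx.nn.Module):
    def __init__(self, leaky_rate=0.1, **kwargs):  # other constructor args omitted
        super().__init__(**kwargs)
        # the slope is now configurable instead of hard-coded to 0.1
        self.leaky_relu = tlx.nn.LeakyReLU(leaky_rate)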
gammagl/layers/conv/heat_conv.py
Outdated
# device = tlx.set_device(device='cpu')
device = tlx.set_device(device='GPU', id=0)
Reviewer: The device only needs to be set in heat_trainer.py; there's no need to repeat it in other files.
gammagl/datasets/ngsim.py
Outdated
os.environ['TL_BACKEND'] = 'torch'

device = tlx.set_device(device='GPU', id=0)
# device = tlx.set_device(device='cpu')
Reviewer: Likewise, there's no need to set it here.
examples/heat/heat_trainer.py
Outdated
print("Epoch [{:0>3d}] ".format(epoch + 1) + " train loss: {:.4f}".format( | ||
train_loss[epoch]) + " val loss: {:.4f}".format(val_loss[epoch])) |
Reviewer: The model needs an evaluation metric, such as acc or auc, to judge its quality; simply printing loss values isn't enough. Also, your dataset appears to include a test set, so you need to run on it as well.
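For trajectory prediction, a displacement-based metric such as RMSE (or ADE/FDE) is more natural than acc/auc; a hedged numpy sketch, assuming predictions and ground truth are arrays of flattened (x, y) coordinates in feet:

import numpy as np

FEET_TO_METERS = 0.3048

def rmse_meters(predictions, ground_truth):
    # root-mean-square displacement error, converted from feet to meters
    return float(np.sqrt(np.mean(np.square(predictions - ground_truth)))) * FEET_TO_METERS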
gammagl/datasets/ngsim.py
Outdated
class NGSIM_US_101(Dataset):
    def __init__(self, data_path, hist_len=10, fut_len=10, save_to=''):
        super().__init__()
        self.data_path = data_path
        self.hist_len = hist_len
        self.fut_len = fut_len
        self.save_to = save_to
        self.url = 'https://raw.githubusercontent.com/gjy1221/NGSIM-US-101/main/data/data.zip'
        self.data_names = os.listdir('{}'.format(self.data_path))
        print(self.data_path)

    def __len__(self):
        return len(self.data_names)

    def __getitem__(self, index):
        data_item = tlx.files.load_npy_to_any(self.data_path, self.data_names[index])
        data_item.edge_attr = data_item.edge_attr.transpose(0, 1)
        data_item.edge_type = data_item.edge_type.transpose(0, 1)
        return data_item

    def download(self):
        path = download_url(self.url, self.save_to)
        with zipfile.ZipFile(path, 'r') as zip_ref:
            # extract all files
            zip_ref.extractall(self.save_to)
Reviewer: When using this dataset in the trainer file, there's no need to call the download method explicitly. The class inherits from Dataset, whose __init__ already runs the download and process methods; see gammagl/data/dataset.py for the implementation details. You can also override the parent's process method in this dataset class so the raw data is processed automatically.
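A sketch of that structure, assuming gammagl.data.Dataset triggers download()/process() from __init__ as described and that download_url is importable from gammagl.data (the process body is a placeholder):

import zipfile
from gammagl.data import Dataset, download_url  # assumed import path

class NGSIM_US_101(Dataset):
    url = 'https://raw.githubusercontent.com/gjy1221/NGSIM-US-101/main/data/data.zip'

    def __init__(self, root, hist_len=10, fut_len=10):
        self.hist_len = hist_len
        self.fut_len = fut_len
        super().__init__(root)  # parent __init__ runs download()/process() as needed

    def download(self):
        path = download_url(self.url, self.raw_dir)
        with zipfile.ZipFile(path, 'r') as zip_ref:
            zip_ref.extractall(self.raw_dir)

    def process(self):
        # convert the raw files into the processed form returned by __getitem__
        ...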
examples/heat/heat_trainer.py
Outdated
os.environ["OMP_NUM_THREADS"] = "4" | ||
os.environ['TL_BACKEND'] = 'torch' |
Reviewer: The backend-setting code should come before import gammagl; otherwise the backend setting won't take effect.
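The suggested ordering, as a sketch:

import os
# must run before tensorlayerx/gammagl are imported, or it has no effect
os.environ['TL_BACKEND'] = 'torch'
os.environ["OMP_NUM_THREADS"] = "4"

import tensorlayerx as tlx
import gammagl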
examples/heat/heat_trainer.py
Outdated
print('loading HEAT model')

net.to(args.device)
Reviewer: There's no need to call net.to(args.device) here; the --device hyperparameter already takes care of it.
examples/heat/heat_trainer.py
Outdated
for j, data in enumerate(valDataloader):
    logits = net(data.x, data.edge_index, data.edge_attr, data.edge_type)
    data.y = data.y[:, 0:args.out_length, :]
    data.y = data.y.view(data.y.shape[0], -1)
Reviewer: The view method isn't supported under the mindspore backend; use the tlx.reshape() interface instead.
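The backend-agnostic form, as a sketch:

# instead of the torch-only data.y.view(...):
data.y = tlx.reshape(data.y, (data.y.shape[0], -1))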
examples/heat/heat_trainer.py
Outdated
for i, data in enumerate(testDataloader):
    logits = net(data.x, data.edge_index, data.edge_attr, data.edge_type)
    data.y = data.y[:, 0:args.out_length, :]
    data.y = data.y.view(data.y.shape[0], -1)
Reviewer: Same as above.
gammagl/layers/conv/heat_conv.py
Outdated
out = tlx.reduce_sum(tlx.multiply(alpha, out), axis=1)

if self.concat:
    out = out.view(out.shape[0], -1)
Reviewer: Same as above.
examples/heat/heat_trainer.py
Outdated
for i, data in enumerate(testDataloader):
    logits = net(data.x, data.edge_index, data.edge_attr, data.edge_type)
    data.y = data.y[:, 0:args.out_length, :]
Reviewer: Same as above.
examples/heat/heat_trainer.py
Outdated
# Convert predictions to numpy arrays
predictions = tlx.convert_to_numpy(test_logits)
ground_truth = tlx.convert_to_numpy(test_y)

# Calculate Euclidean distance
distance = np.sqrt(np.sum(np.square(predictions - ground_truth)))
total_distance += distance
total_samples += len(predictions)
Reviewer: This computation could also be replaced with tlx interfaces.
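A hedged sketch staying in tlx ops, assuming tlx.sqrt, tlx.square, and tlx.reduce_sum are available under the chosen backend:

# Euclidean distance computed with tlx ops instead of numpy
distance = tlx.sqrt(tlx.reduce_sum(tlx.square(test_logits - test_y)))
total_distance += float(tlx.convert_to_numpy(distance))
total_samples += test_logits.shape[0]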
examples/heat/heat_trainer.py
Outdated
parser.add_argument("--dropout", type=int, default=0.5, help="dropout rate") | ||
parser.add_argument("--leaky_rate", type=int, default=0.1, help="LeakyReLU rate") | ||
|
||
parser.add_argument("--lr", type=int, default=0.001, help="learning rate") |
Reviewer: type should be float here.
gammagl/layers/conv/heat_conv.py
Outdated
    in_channels_node : int
        Size of each input node feature.
    in_channels_edge_attr : int
        Size of each input edge attribute.
    in_channels_edge_type : int
        Size of each input edge type.
    node_emb_size : int
        Size of the node embedding.
    edge_attr_emb_size : int
        Size of the edge attribute embedding.
    edge_type_emb_size : int
        Size of the edge type embedding.
    out_channels : int
        Size of each output node feature.
    heads : int, optional
        Number of attention heads. (default: 3)
    concat : bool, optional
        If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)

    """

    """
    1. consider 2 types of nodes: v, p
    2. consider 4 types of edges: v->v, v->p, p->v, p->p
    3. assume that different nodes have the same dimension, but different vector space.
    4. assume that different edges have the same dimension, but different vector space.
    """

    def __init__(self, in_channels_node=64, in_channels_edge_attr=5, in_channels_edge_type=4, node_emb_size=64,
                 edge_attr_emb_size=64, edge_type_emb_size=64, out_channels=128, heads=3, concat=True):
Reviewer: For input parameters with default values, remember to mark them as optional in the rst docs and state the default.
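A sketch of the convention, mirroring the heads/concat entries already in the docstring:

in_channels_node : int, optional
    Size of each input node feature. (default: 64)
out_channels : int, optional
    Size of each output node feature. (default: 128)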
gammagl/models/heat.py
Outdated
concat: bool, optional
    If set to `True`, the multi-head attentions are concatenated. Otherwise, they are averaged.
Reviewer: In __init__, concat has no default value; can it really be documented as optional?
PR description: Implements the HEAT algorithm from "Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent Trajectory Prediction" (https://arxiv.org/abs/2106.07161); adds the NGSIM US-101 dataset to datasets, with preprocessing already completed so the data is ready for use with HEAT.