def pred(self, g_data_list, g_data_list_reverse):
    """Run stage-1 inference over all graphs in mini-batches of 32.

    Args:
        g_data_list: forward-direction graph ``Data`` objects.
        g_data_list_reverse: reverse-direction graphs, aligned by index.

    Returns:
        A 1-D tensor with one prediction per input graph, in input order.
    """
    self.stage1.eval()
    outputs = []
    with torch.no_grad():
        for chunk in gen_batch_idx(list(range(len(g_data_list))), 32):
            fwd = Batch.from_data_list([g_data_list[j] for j in chunk]).to(self.device)
            bwd = Batch.from_data_list([g_data_list_reverse[j] for j in chunk]).to(self.device)
            # L2-normalize node features of both graph directions.
            fwd_nodes = F.normalize(fwd.x, p=2, dim=-1)
            bwd_nodes = F.normalize(bwd.x, p=2, dim=-1)
            out = self.stage1(fwd_nodes, fwd.edge_index, fwd.batch,
                              bwd_nodes, bwd.edge_index).squeeze()
            # A size-1 chunk squeezes to a scalar; restore the batch axis.
            if out.dim() == 0:
                out = out.unsqueeze(0)
            outputs.append(out)
    return torch.cat(outputs, dim=0)
def fit_full_train(self, train_g_data_list, arch_encoding, epoch):
    """Train ``self.predictor`` on architecture pairs drawn from ``ProductList``.

    For each index pair the regression target is the normalized graph edit
    distance between the two architectures' path encodings.

    Args:
        train_g_data_list: graph ``Data`` objects, one per architecture.
        arch_encoding: per-architecture path encodings, aligned by index.
        epoch: epoch number, used only to stamp the checkpoint filename.
    """
    meters = MetricLogger(delimiter=" ")
    self.predictor.train()
    idx_list = list(range(len(train_g_data_list)))
    # NOTE(review): presumably ProductList yields (idx1, idx2) index pairs
    # (cartesian product) — confirm against its definition.
    idx_dataset = ProductList(idx_list)
    training_data = torch.utils.data.DataLoader(idx_dataset,
                                                batch_size=self.batch_size,
                                                shuffle=True)
    counter = 0  # number of pairs consumed so far (bookkeeping only)
    for i, v in enumerate(training_data):
        pair1_idx = v[0].cpu().numpy()
        pair2_idx = v[1].cpu().numpy()
        counter += len(pair1_idx)
        data_list_pair1 = []
        arch_path_encoding_pair1 = []
        data_list_pair2 = []
        arch_path_encoding_pair2 = []
        for pair_idx in zip(pair1_idx, pair2_idx):
            idx1, idx2 = pair_idx
            g_d_1 = train_g_data_list[idx1]
            data_list_pair1.append(g_d_1)
            arch_path_encoding_pair1.append(arch_encoding[idx1])
            g_d_2 = train_g_data_list[idx2]
            data_list_pair2.append(g_d_2)
            arch_path_encoding_pair2.append(arch_encoding[idx2])
        # Ground truth: normalized edit distance per pair.  The comprehension's
        # ``i`` lives in its own scope and does not clobber the enumerate index.
        dist_gt = torch.tensor([edit_distance_normalization(arch_path_encoding_pair1[i],
                                                            arch_path_encoding_pair2[i],
                                                            self.node_num)
                                for i in range(len(arch_path_encoding_pair1))],
                               dtype=torch.float32)
        batch1 = Batch.from_data_list(data_list_pair1)
        batch1 = batch1.to(self.device)
        batch2 = Batch.from_data_list(data_list_pair2)
        batch2 = batch2.to(self.device)
        dist_gt = dist_gt.to(self.device)
        batch_nodes_1, batch_edge_idx_1, batch_idx_1 = batch1.x, batch1.edge_index, batch1.batch
        batch_nodes_2, batch_edge_idx_2, batch_idx_2 = batch2.x, batch2.edge_index, batch2.batch
        prediction = self.predictor(batch_nodes_1, batch_edge_idx_1, batch_idx_1,
                                    batch_nodes_2, batch_edge_idx_2, batch_idx_2)
        prediction = prediction.squeeze(dim=-1)
        loss = self.criterion(prediction, dist_gt)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()
        meters.update(loss=loss.item())
    save_dir = os.path.join(self.save_dir, f'unsupervised_ss_rl_epoch_{epoch}.pt')
    if self.save_model:
        torch.save(self.predictor.state_dict(), save_dir)
    if self.logger:
        self.logger.info(meters.delimiter.join(['{loss}'.format(loss=str(meters))]))
    else:
        print(meters.delimiter.join(['{loss}'.format(loss=str(meters))]))
def fit_train(self, total_archs, epoch):
    """Train ``self.predictor`` on randomly paired architectures for one epoch.

    Pairs are formed by zipping two independently shuffled index lists; each
    architecture's graph and path encoding are produced on the fly by
    ``gen_arch_info``.  The target is the normalized graph edit distance.

    Args:
        total_archs: raw architecture descriptions accepted by ``gen_arch_info``.
        epoch: epoch number, used only to stamp the checkpoint filename.
    """
    meters = MetricLogger(delimiter=" ")
    self.predictor.train()
    idx_list = list(range(len(total_archs)))
    idx_list2 = list(idx_list)
    random.shuffle(idx_list)
    batch_idx_list_1 = gen_batch_idx_gen(idx_list, self.batch_size, drop_last=True)
    random.shuffle(idx_list2)
    batch_idx_list_2 = gen_batch_idx_gen(idx_list2, self.batch_size, drop_last=True)
    counter = 0  # number of pairs consumed so far (bookkeeping only)
    for i, (pair1_idx, pair2_idx) in enumerate(zip(batch_idx_list_1, batch_idx_list_2)):
        counter += len(pair1_idx)
        data_list_pair1 = []
        arch_path_encoding_pair1 = []
        data_list_pair2 = []
        arch_path_encoding_pair2 = []
        for pair_idx in zip(pair1_idx, pair2_idx):
            idx1, idx2 = pair_idx
            arch_1_info = gen_arch_info(total_archs[idx1])
            g_d_1 = arch_1_info['g_data']
            data_list_pair1.append(g_d_1)
            arch_path_encoding_pair1.append(arch_1_info['pe_path_enc_aware_vec'])
            arch_2_info = gen_arch_info(total_archs[idx2])
            data_list_pair2.append(arch_2_info['g_data'])
            arch_path_encoding_pair2.append(arch_2_info['pe_path_enc_aware_vec'])
        # Ground truth: normalized edit distance per pair (comprehension ``i``
        # is scoped to the comprehension and does not clobber the loop index).
        dist_gt = torch.tensor([edit_distance_normalization(arch_path_encoding_pair1[i],
                                                            arch_path_encoding_pair2[i],
                                                            self.node_num)
                                for i in range(len(arch_path_encoding_pair1))],
                               dtype=torch.float32)
        batch1 = Batch.from_data_list(data_list_pair1)
        batch1 = batch1.to(self.device)
        batch2 = Batch.from_data_list(data_list_pair2)
        batch2 = batch2.to(self.device)
        dist_gt = dist_gt.to(self.device)
        batch_nodes_1, batch_edge_idx_1, batch_idx_1 = batch1.x, batch1.edge_index, batch1.batch
        batch_nodes_2, batch_edge_idx_2, batch_idx_2 = batch2.x, batch2.edge_index, batch2.batch
        prediction = self.predictor(batch_nodes_1, batch_edge_idx_1, batch_idx_1,
                                    batch_nodes_2, batch_edge_idx_2, batch_idx_2)
        prediction = prediction.squeeze(dim=-1)
        loss = self.criterion(prediction, dist_gt)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()
        meters.update(loss=loss.item())
    save_dir = os.path.join(self.save_dir, f'unsupervised_ss_rl_epoch_{epoch}.pt')
    if self.save_model:
        torch.save(self.predictor.state_dict(), save_dir)
    if self.logger:
        self.logger.info(meters.delimiter.join(['{loss}'.format(loss=str(meters))]))
    else:
        print(meters.delimiter.join(['{loss}'.format(loss=str(meters))]))
def fit(self, edge_index, node_feature, edge_index_reverse, node_feature_reverse, val_accuracy):
    """Train ``self.stage1`` to regress validation accuracy from bidirectional graphs.

    Runs ``self.epochs`` epochs of shuffled mini-batches of size 10 and
    returns the running-average training loss.

    Args:
        edge_index / node_feature: per-architecture forward graph tensors.
        edge_index_reverse / node_feature_reverse: reverse-direction graphs.
        val_accuracy: regression targets, aligned by index.
    """
    meters = MetricLogger(delimiter=" ")
    self.stage1.train()
    for epoch in range(self.epochs):
        idx_list = list(range(len(edge_index)))
        random.shuffle(idx_list)
        batch_idx_list = gen_batch_idx(idx_list, 10)
        for i, batch_idx in enumerate(batch_idx_list):
            data_list = []
            data_list_reverse = []
            target_list = []
            for idx in batch_idx:
                g_d = Data(edge_index=edge_index[idx].long(), x=node_feature[idx].float())
                data_list.append(g_d)
                g_d_reverse = Data(edge_index=edge_index_reverse[idx].long(),
                                   x=node_feature_reverse[idx].float())
                data_list_reverse.append(g_d_reverse)
                target_list.append(val_accuracy[idx])
            val_tensor = torch.tensor(target_list, dtype=torch.float32)
            batch = Batch.from_data_list(data_list)
            batch = batch.to(self.device)
            val_tensor = val_tensor.to(self.device)
            # NOTE: this re-binds the loop variable ``batch_idx`` to the PyG
            # node-to-graph assignment vector; harmless since the loop value
            # is not used after this point.
            batch_nodes, batch_edge_idx, batch_idx = batch.x, batch.edge_index, batch.batch
            # L2-normalize node features for both directions.
            batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
            batch_reverse = Batch.from_data_list(data_list_reverse)
            batch_reverse = batch_reverse.to(self.device)
            batch_nodes_reverse, batch_edge_idx_reverse, batch_idx_reverse = batch_reverse.x, \
                batch_reverse.edge_index, \
                batch_reverse.batch
            batch_nodes_reverse = F.normalize(batch_nodes_reverse, p=2, dim=-1)
            pred = self.stage1(batch_nodes, batch_edge_idx, batch_idx,
                               batch_nodes_reverse, batch_edge_idx_reverse)
            pred = pred.squeeze()
            loss = self.criterion(pred, val_tensor)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # Fractional-epoch schedule: advances every 30 mini-batches.
            self.scheduler.step(epoch + int(i / 30))
            meters.update(loss=loss.item())
    return meters.meters['loss'].avg
def pred(self, edge_index, node_feature):
    """Batched (size 64) inference with ``self.nas_agent``.

    Returns:
        Three 1-D tensors aligned with the input order: the raw prediction
        and the agent's mean and std heads.
    """
    self.nas_agent.eval()
    preds, means, stds = [], [], []
    all_idx = list(range(len(edge_index)))
    with torch.no_grad():
        for chunk in gen_batch_idx(all_idx, 64):
            graphs = [
                Data(edge_index=edge_index[j].long(), x=node_feature[j].float())
                for j in chunk
            ]
            batch = Batch.from_data_list(graphs).to(self.device)
            p, m, s = self.nas_agent(batch.x, batch.edge_index, batch.batch)
            p, m, s = p.squeeze(), m.squeeze(), s.squeeze()
            # A single-graph chunk squeezes to 0-d; restore the batch axis.
            if p.dim() == 0:
                p, m, s = p.unsqueeze(0), m.unsqueeze(0), s.unsqueeze(0)
            preds.append(p)
            means.append(m)
            stds.append(s)
    return torch.cat(preds, dim=0), torch.cat(means, dim=0), torch.cat(stds, dim=0)
def fit(self, edge_index, node_feature, val_accuracy, logger=None):
    """Train ``self.nas_agent`` to predict validation accuracy from graphs.

    The agent returns ``(pred, mean, std)``; the loss is computed from the
    mean/std heads against the targets (``self.criterion`` takes
    ``(mean, std, target)`` — presumably a Gaussian-NLL-style loss, verify
    against its definition).  Returns the running-average training loss.
    """
    meters = MetricLogger(delimiter=" ")
    self.nas_agent.train()
    for epoch in range(self.epoch):
        idx_list = list(range(len(edge_index)))
        random.shuffle(idx_list)
        batch_idx_list = gen_batch_idx(idx_list, self.batch_size)
        counter = 0  # samples consumed this epoch (bookkeeping only)
        for batch_idx in batch_idx_list:
            counter += len(batch_idx)
            data_list = []
            target_list = []
            for idx in batch_idx:
                g_d = Data(edge_index=edge_index[idx].long(), x=node_feature[idx].float())
                data_list.append(g_d)
                target_list.append(val_accuracy[idx])
            val_tensor = torch.tensor(target_list, dtype=torch.float32)
            batch = Batch.from_data_list(data_list)
            batch = batch.to(self.device)
            val_tensor = val_tensor.to(self.device)
            # Re-binds the loop variable ``batch_idx`` to the PyG assignment
            # vector; harmless, it is not used as a loop value afterwards.
            batch_nodes, batch_edge_idx, batch_idx = batch.x, batch.edge_index, batch.batch
            # batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
            pred, mean, std = self.nas_agent(batch_nodes, batch_edge_idx, batch_idx)
            val_tensor = val_tensor.unsqueeze(dim=1)  # (N,) -> (N, 1) to match heads
            loss = self.criterion(mean, std, val_tensor)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            meters.update(loss=loss.item())
        if logger:
            logger.info(meters.delimiter.join(['{loss}'.format(loss=str(meters))]))
    return meters.meters['loss'].avg
def fit_train_g_data(self, g_data, accuracy, logger=None):
    """Supervised training of ``self.predictor`` on graphs with accuracy targets.

    Runs ``self.epoch`` epochs of shuffled mini-batches; when
    ``self.save_model`` is set a checkpoint is written each epoch.

    Returns:
        The running-average training loss.
    """
    meters = MetricLogger(delimiter=" ")
    self.predictor.train()
    for epoch in range(self.epoch):
        order = list(range(len(g_data)))
        random.shuffle(order)
        seen = 0
        for step, chunk in enumerate(gen_batch_idx(order, self.batch_size)):
            seen += len(chunk)
            graphs = [g_data[k] for k in chunk]
            targets = torch.tensor([accuracy[k] for k in chunk], dtype=torch.float32)
            batch = Batch.from_data_list(graphs).to(self.device)
            targets = targets.to(self.device)
            # Node features are fed unnormalized (normalization disabled upstream).
            out = self.predictor(batch.x, batch.edge_index, batch.batch)
            out = out.squeeze()
            loss = self.criterion(out, targets)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            meters.update(loss=loss.item())
        ckpt_path = os.path.join(self.save_dir, f'supervised_gin_epoch_{epoch}.pt')
        if self.save_model:
            torch.save(self.predictor.state_dict(), ckpt_path)
    return meters.meters['loss'].avg
def inference(self, edge_index, node_feature, arch_encoding, batch_idx_test_1, batch_idx_test_2):
    """Evaluate the pairwise distance predictor on pre-built test batches.

    For each batch pair the predictor output is mapped back to an edit
    distance via ``-log(p) * node_num`` (inverse of the training-time
    normalization) and compared to the exact ``edit_distance`` ground truth.
    Mean absolute error and the fraction of pairs with error < 1 are logged
    (or printed when no logger is configured).

    Args:
        edge_index / node_feature: per-architecture graph tensors.
        arch_encoding: per-architecture path encodings, aligned by index.
        batch_idx_test_1 / batch_idx_test_2: parallel lists of index batches
            forming the test pairs.
    """
    self.predictor.eval()
    error_list = []
    precision_list = []
    # FIX: evaluate under no_grad like the other eval paths in this file;
    # the original built autograd graphs during evaluation for no reason.
    with torch.no_grad():
        for i, pair1_idx in enumerate(batch_idx_test_1):
            pair2_idx = batch_idx_test_2[i]
            data_list_pair1 = []
            arch_path_encoding_pair1 = []
            data_list_pair2 = []
            arch_path_encoding_pair2 = []
            for pair_idx in zip(pair1_idx, pair2_idx):
                idx1, idx2 = pair_idx
                g_d_1 = Data(edge_index=edge_index[idx1].long(), x=node_feature[idx1].float())
                data_list_pair1.append(g_d_1)
                arch_path_encoding_pair1.append(arch_encoding[idx1])
                g_d_2 = Data(edge_index=edge_index[idx2].long(), x=node_feature[idx2].float())
                data_list_pair2.append(g_d_2)
                arch_path_encoding_pair2.append(arch_encoding[idx2])
            # Exact (un-normalized) edit distance as ground truth.  Index
            # renamed to ``j`` to stop shadowing the enumerate counter.
            dist_gt = torch.tensor([edit_distance(arch_path_encoding_pair1[j],
                                                  arch_path_encoding_pair2[j])
                                    for j in range(len(arch_path_encoding_pair1))],
                                   dtype=torch.float32)
            batch1 = Batch.from_data_list(data_list_pair1)
            batch1 = batch1.to(self.device)
            batch2 = Batch.from_data_list(data_list_pair2)
            batch2 = batch2.to(self.device)
            dist_gt = dist_gt.to(self.device)
            batch_nodes_1, batch_edge_idx_1, batch_idx_1 = batch1.x, batch1.edge_index, batch1.batch
            batch_nodes_2, batch_edge_idx_2, batch_idx_2 = batch2.x, batch2.edge_index, batch2.batch
            prediction = self.predictor(batch_nodes_1, batch_edge_idx_1, batch_idx_1,
                                        batch_nodes_2, batch_edge_idx_2, batch_idx_2)
            # Invert the training normalization: distance = -log(p) * node_num.
            prediction = -1 * torch.log(prediction.squeeze(dim=-1)) * self.node_num
            errors = torch.abs(dist_gt - prediction)
            precision = (torch.sum(errors < 1) * 1.) / errors.size(0)
            error_list.append(torch.mean(errors).item())
            precision_list.append(precision.item())
    if self.logger:
        self.logger.info(f'Error is {np.mean(np.array(error_list))}, Precision is {np.mean(np.array(precision_list))}')
    else:
        print(f'Error is {np.mean(np.array(error_list))}, Precision is {np.mean(np.array(precision_list))}')
def pred_g_data(self, g_data):
    """Predict a score for every graph in ``g_data`` (batches of 64).

    Returns:
        A 1-D tensor with one prediction per input graph, in input order.
    """
    self.predictor.eval()
    results = []
    chunks = gen_batch_idx(list(range(len(g_data))), 64)
    with torch.no_grad():
        for chunk in chunks:
            batch = Batch.from_data_list([g_data[j] for j in chunk]).to(self.device)
            out = self.predictor(batch.x, batch.edge_index, batch.batch).squeeze()
            # A single-graph chunk squeezes to 0-d; restore the batch axis.
            if out.dim() == 0:
                out = out.unsqueeze(0)
            results.append(out)
    return torch.cat(results, dim=0)
def fit(self, edge_index, node_feature, edge_index_reverse, node_feature_reverse, val_accuracy, val_accuracy_cls):
    """Jointly train the two-stage predictor.

    Stage 1 (``self.stage1``) is a classifier trained on ``val_accuracy_cls``
    labels; stage 2 (``self.stage2``) is a regressor trained on
    ``val_accuracy`` using only the samples stage 1 classifies as non-zero.

    Returns:
        Tuple of (average classification loss, average regression loss).
    """
    meters_cls = MetricLogger(delimiter=" ")
    meters_regerss = MetricLogger(delimiter=" ")
    self.stage1.train()
    self.stage2.train()
    for epoch in range(self.epochs):
        idx_list = list(range(len(edge_index)))
        random.shuffle(idx_list)
        batch_idx_list = gen_batch_idx(idx_list, 10)
        for i, batch_idx in enumerate(batch_idx_list):
            data_list = []
            data_list_reverse = []
            target_list_cls = []
            for idx in batch_idx:
                g_d = Data(edge_index=edge_index[idx].long(), x=node_feature[idx].float())
                data_list.append(g_d)
                g_d_reverse = Data(edge_index=edge_index_reverse[idx].long(),
                                   x=node_feature_reverse[idx].float())
                data_list_reverse.append(g_d_reverse)
                target_list_cls.append(val_accuracy_cls[idx])
            val_cls_tensor = torch.tensor(target_list_cls, dtype=torch.long)
            batch = Batch.from_data_list(data_list)
            batch = batch.to(self.device)
            val_cls_tensor = val_cls_tensor.to(self.device)
            batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
            batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
            batch_reverse = Batch.from_data_list(data_list_reverse)
            batch_reverse = batch_reverse.to(self.device)
            batch_nodes_reverse, batch_edge_idx_reverse, batch_idx_reverse = batch_reverse.x, \
                batch_reverse.edge_index, \
                batch_reverse.batch
            batch_nodes_reverse = F.normalize(batch_nodes_reverse, p=2, dim=-1)
            pred_cls = self.stage1(batch_nodes, batch_edge_idx, batch_idx_g,
                                   batch_nodes_reverse, batch_edge_idx_reverse).squeeze()
            loss_stage1 = self.criterion_ce(pred_cls, val_cls_tensor)
            self.optimizer_cls.zero_grad()
            loss_stage1.backward()
            self.optimizer_cls.step()
            # Fractional-epoch schedule: advances every 30 mini-batches.
            self.scheduler_cls.step(epoch + int(i / 30))
            meters_cls.update(loss=loss_stage1.item())
            if pred_cls.dim() == 1:
                # BUGFIX: was ``unsequeeze_`` — an AttributeError whenever a
                # size-1 batch squeezed pred_cls down to one dimension.
                pred_cls.unsqueeze_(dim=0)
            pred_max = torch.argmax(pred_cls, dim=1)
            if torch.sum(pred_max) == 0:
                # No sample predicted non-zero: nothing for stage 2 this step.
                continue
            # Rebuild the batch from only the samples stage 1 kept.
            data_list = []
            data_list_reverse = []
            target_list = []
            for k, idx in enumerate(batch_idx):
                if pred_max[k] == 0:
                    continue
                g_d = Data(edge_index=edge_index[idx].long(), x=node_feature[idx].float())
                data_list.append(g_d)
                g_d_reverse = Data(edge_index=edge_index_reverse[idx].long(),
                                   x=node_feature_reverse[idx].float())
                data_list_reverse.append(g_d_reverse)
                target_list.append(val_accuracy[idx])
            val_tensor = torch.tensor(target_list, dtype=torch.float32)
            val_tensor = val_tensor.to(self.device)
            batch = Batch.from_data_list(data_list)
            batch = batch.to(self.device)
            batch_nodes, batch_edge_idx, batch_idx_G = batch.x, batch.edge_index, batch.batch
            batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
            batch_reverse = Batch.from_data_list(data_list_reverse)
            batch_reverse = batch_reverse.to(self.device)
            batch_nodes_reverse, batch_edge_idx_reverse, batch_idx_reverse = batch_reverse.x, \
                batch_reverse.edge_index, \
                batch_reverse.batch
            batch_nodes_reverse = F.normalize(batch_nodes_reverse, p=2, dim=-1)
            pred_regress = self.stage2(batch_nodes, batch_edge_idx, batch_idx_G,
                                       batch_nodes_reverse, batch_edge_idx_reverse).squeeze()
            # A single selected sample squeezes to a scalar; match target shape.
            if pred_regress.dim() == 0:
                loss_stage2 = self.criterion(pred_regress, val_tensor[0])
            else:
                loss_stage2 = self.criterion(pred_regress, val_tensor)
            self.optimizer_regress.zero_grad()
            loss_stage2.backward()
            self.optimizer_regress.step()
            self.scheduler_regress.step(epoch + int(i / 30))
            meters_regerss.update(loss=loss_stage2.item())
    return meters_cls.meters['loss'].avg, meters_regerss.meters['loss'].avg
def pred(self, edge_index, node_feature, edge_index_reverse, node_feature_reverse):
    """Two-stage inference: classify with stage 1, regress survivors with stage 2.

    Samples stage 1 assigns class 0 get prediction 0; the rest get the
    stage-2 regression output.

    Returns:
        A 1-D tensor with exactly one prediction per input graph, in order.
    """
    pred_list = []
    idx_list = list(range(len(edge_index)))
    self.stage1.eval()
    self.stage2.eval()
    batch_idx_list = gen_batch_idx(idx_list, 32)
    with torch.no_grad():
        for i, batch_idx in enumerate(batch_idx_list):
            data_list = []
            data_list_reverse = []
            for idx in batch_idx:
                g_d = Data(edge_index=edge_index[idx].long(), x=node_feature[idx].float())
                data_list.append(g_d)
                g_d_reverse = Data(edge_index=edge_index_reverse[idx].long(),
                                   x=node_feature_reverse[idx].float())
                data_list_reverse.append(g_d_reverse)
            batch = Batch.from_data_list(data_list)
            batch = batch.to(self.device)
            batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
            batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
            batch_reverse = Batch.from_data_list(data_list_reverse)
            batch_reverse = batch_reverse.to(self.device)
            batch_nodes_reverse, batch_edge_idx_reverse, batch_idx_reverse = batch_reverse.x, \
                batch_reverse.edge_index, \
                batch_reverse.batch
            batch_nodes_reverse = F.normalize(batch_nodes_reverse, p=2, dim=-1)
            pred_cls = self.stage1(batch_nodes, batch_edge_idx, batch_idx_g,
                                   batch_nodes_reverse, batch_edge_idx_reverse).squeeze()
            if pred_cls.dim() == 1:
                pred_cls.unsqueeze_(dim=0)
            pred_max = torch.argmax(pred_cls, dim=1)
            if pred_max.dim() == 0:
                pred_max.unsqueeze_(dim=0)
            if torch.sum(pred_max) == 0:
                # BUGFIX: the original skipped this batch entirely, so the
                # returned tensor was shorter than the input and misaligned
                # with it.  Emit zeros for every sample in the batch instead.
                pred_list.append(torch.zeros_like(pred_max, dtype=torch.float32))
                continue
            # Rebuild the batch from only the samples stage 1 kept.
            data_list = []
            data_list_reverse = []
            for k, idx in enumerate(batch_idx):
                if pred_max[k] == 0:
                    continue
                g_d = Data(edge_index=edge_index[idx].long(), x=node_feature[idx].float())
                data_list.append(g_d)
                g_d_reverse = Data(edge_index=edge_index_reverse[idx].long(),
                                   x=node_feature_reverse[idx].float())
                data_list_reverse.append(g_d_reverse)
            batch = Batch.from_data_list(data_list)
            batch = batch.to(self.device)
            batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
            batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
            batch_reverse = Batch.from_data_list(data_list_reverse)
            batch_reverse = batch_reverse.to(self.device)
            batch_nodes_reverse, batch_edge_idx_reverse, batch_idx_reverse = batch_reverse.x, \
                batch_reverse.edge_index, \
                batch_reverse.batch
            batch_nodes_reverse = F.normalize(batch_nodes_reverse, p=2, dim=-1)
            pred_regress = self.stage2(batch_nodes, batch_edge_idx, batch_idx_g,
                                       batch_nodes_reverse, batch_edge_idx_reverse).squeeze()
            pred = torch.zeros_like(pred_max, dtype=torch.float32)
            if pred_regress.dim() == 0:
                pred_regress.unsqueeze_(dim=0)
            # Scatter stage-2 outputs back to their positions; class-0 slots stay 0.
            counter = 0
            for j in range(pred.size(0)):
                if pred_max[j] == 0:
                    pred[j] = 0
                else:
                    pred[j] = pred_regress[counter]
                    counter += 1
            if len(pred.size()) == 0:
                pred.unsqueeze_(0)
            pred_list.append(pred)
    return torch.cat(pred_list, dim=0)
def fit_train(self, train_g_data_list, arch_encoding, epoch):
    """Train ``self.predictor`` on randomly paired architectures for one epoch.

    Two independently shuffled index lists are zipped to form pairs.  When
    ``args.add_corresponding`` is set, every architecture is additionally
    paired with itself and the combined lists are reshuffled with a shared
    permutation so the pairing is preserved.  The regression target type is
    selected by ``args.ged_type``.
    """
    meters = MetricLogger(delimiter=" ")
    self.predictor.train()
    idx_list = list(range(len(train_g_data_list)))
    idx_list2 = list(idx_list)
    random.shuffle(idx_list)
    random.shuffle(idx_list2)
    if self.args and self.args.add_corresponding:
        # Append identity pairs (i, i) to both lists ...
        idx = list(range(len(train_g_data_list)))
        idx2 = list(range(len(train_g_data_list)))
        idx_list.extend(idx)
        idx_list2.extend(idx2)
        # ... then apply the SAME permutation to both so pairs stay aligned.
        indices_list = list(range(len(idx_list)))
        random.shuffle(indices_list)
        idx_list = [idx_list[i] for i in indices_list]
        idx_list2 = [idx_list2[i] for i in indices_list]
    batch_idx_list_1 = gen_batch_idx_gen(idx_list, self.batch_size, drop_last=True)
    batch_idx_list_2 = gen_batch_idx_gen(idx_list2, self.batch_size, drop_last=True)
    counter = 0  # number of pairs consumed so far (bookkeeping only)
    for i, (pair1_idx, pair2_idx) in enumerate(zip(batch_idx_list_1, batch_idx_list_2)):
        counter += len(pair1_idx)
        data_list_pair1 = []
        arch_path_encoding_pair1 = []
        data_list_pair2 = []
        arch_path_encoding_pair2 = []
        for pair_idx in zip(pair1_idx, pair2_idx):
            idx1, idx2 = pair_idx
            g_d_1 = train_g_data_list[idx1]
            data_list_pair1.append(g_d_1)
            arch_path_encoding_pair1.append(arch_encoding[idx1])
            g_d_2 = train_g_data_list[idx2]
            data_list_pair2.append(g_d_2)
            arch_path_encoding_pair2.append(arch_encoding[idx2])
        # Ground truth per pair; the comprehension's ``i`` is scoped to the
        # comprehension and does not clobber the enumerate index.
        if self.args.ged_type == 'normalized':
            dist_gt = torch.tensor([edit_distance_normalization(arch_path_encoding_pair1[i],
                                                                arch_path_encoding_pair2[i],
                                                                self.node_num)
                                    for i in range(len(arch_path_encoding_pair1))],
                                   dtype=torch.float32)
        elif self.args.ged_type == 'wo_normalized':
            dist_gt = torch.tensor([edit_distance(arch_path_encoding_pair1[i],
                                                  arch_path_encoding_pair2[i])
                                    for i in range(len(arch_path_encoding_pair1))],
                                   dtype=torch.float32)
        else:
            raise ValueError(f'The ged type {self.args.ged_type} does not support!')
        batch1 = Batch.from_data_list(data_list_pair1)
        batch1 = batch1.to(self.device)
        batch2 = Batch.from_data_list(data_list_pair2)
        batch2 = batch2.to(self.device)
        dist_gt = dist_gt.to(self.device)
        batch_nodes_1, batch_edge_idx_1, batch_idx_1 = batch1.x, batch1.edge_index, batch1.batch
        batch_nodes_2, batch_edge_idx_2, batch_idx_2 = batch2.x, batch2.edge_index, batch2.batch
        prediction = self.predictor(batch_nodes_1, batch_edge_idx_1, batch_idx_1,
                                    batch_nodes_2, batch_edge_idx_2, batch_idx_2)
        prediction = prediction.squeeze(dim=-1)
        loss = self.criterion(prediction, dist_gt)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()
        meters.update(loss=loss.item())
    save_dir = os.path.join(self.save_dir, f'unsupervised_ss_rl_epoch_{epoch}.pt')
    if self.save_model:
        torch.save(self.predictor.state_dict(), save_dir)
    if self.logger:
        self.logger.info(meters.delimiter.join(['{loss}'.format(loss=str(meters))]))
    else:
        print(meters.delimiter.join(['{loss}'.format(loss=str(meters))]))
def train_nested(train_loader, model, criterion, optimizer, epoch, args, logger):
    """One epoch of nested contrastive training.

    Each loader item supplies a list of graphs plus path encodings; within
    each loader batch a random subset of ``args.train_samples`` anchors is
    split into chunks of ``args.batch_step`` and each chunk drives one
    optimizer step.

    Returns:
        List of center matrices (one ndarray per loader batch — the centers
        from the LAST chunk of that batch).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch),
                             logger=logger)
    # switch to train mode
    model.train()
    device = torch.device(f'cuda:{args.gpu}')
    end = time.time()
    center_list = []
    step = args.batch_step
    for i, (g_d, path_encodings) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        batch = Batch.from_data_list(g_d)
        batch = batch.to(device)
        indices = list(range(len(g_d)))
        random.shuffle(indices)
        indices = indices[:args.train_samples]
        # Chunk the sampled anchors; the comprehension's ``i`` is scoped to
        # the comprehension and does not clobber the loader index.
        indices_list = [
            indices[i * step:(i + 1) * step]
            for i in range(args.train_samples // step)
        ]
        for idxss, sample_ids in enumerate(indices_list):
            # compute output
            logits, label, centers = model(batch=batch,
                                           path_encoding=path_encodings,
                                           device=device,
                                           search_space=args.search_space,
                                           sample_ids=sample_ids,
                                           logger=logger)
            loss = criterion(logits, label)
            if args.center_regularization:
                # Penalize similarity between distinct centers: mean of the
                # off-diagonal entries of the Gram matrix (diagonal masked out).
                center_dist = torch.mm(centers, centers.T)
                masks = torch.ones_like(center_dist)
                eigen_val = list(range(center_dist.size(0)))
                masks[eigen_val, eigen_val] = 0
                center_loss = 0.5 * torch.mean(masks * center_dist)
                loss = loss + 0.5 * center_loss
            size_logits = logits.size(0)
            # acc1/acc5 are (K+1)-way contrast classifier accuracy
            # measure accuracy and record loss
            acc1, acc5 = accuracy(logits, label, topk=(1, 5))
            losses.update(loss.item(), size_logits)
            top1.update(acc1[0], size_logits)
            top5.update(acc5[0], size_logits)
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            progress.display(i)
        # NOTE(review): ``centers`` is the value from the last chunk; if
        # indices_list were ever empty this would raise NameError — confirm
        # train_samples >= batch_step upstream.
        center_list.append(centers.cpu().detach().numpy())
    return center_list