Example #1
    def pred(self, g_data_list, g_data_list_reverse):
        pred_list = []
        idx_list = list(range(len(g_data_list)))
        self.stage1.eval()
        batch_idx_list = gen_batch_idx(idx_list, 32)
        with torch.no_grad():
            for batch_idx in batch_idx_list:
                data_list = [g_data_list[idx] for idx in batch_idx]
                data_list_reverse = [
                    g_data_list_reverse[idx] for idx in batch_idx
                ]

                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
                batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)

                batch_reverse = Batch.from_data_list(data_list_reverse)
                batch_reverse = batch_reverse.to(self.device)
                batch_nodes_reverse = batch_reverse.x
                batch_edge_idx_reverse = batch_reverse.edge_index
                batch_idx_reverse = batch_reverse.batch
                batch_nodes_reverse = F.normalize(batch_nodes_reverse,
                                                  p=2,
                                                  dim=-1)

                pred = self.stage1(batch_nodes, batch_edge_idx, batch_idx_g,
                                   batch_nodes_reverse,
                                   batch_edge_idx_reverse).squeeze()
                if pred.dim() == 0:
                    pred.unsqueeze_(0)
                pred_list.append(pred)
        return torch.cat(pred_list, dim=0)
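
Note: the examples in this listing batch indices with a gen_batch_idx helper whose definition is not included. Judging only from the call sites (an index list plus a batch size, yielding lists of indices), a minimal sketch might look like the following; treat it as a hypothetical reconstruction, not the project's actual code.

# Hypothetical reconstruction; the real gen_batch_idx is not shown here.
def gen_batch_idx(idx_list, batch_size):
    # split the index list into consecutive chunks of at most batch_size
    return [idx_list[i:i + batch_size]
            for i in range(0, len(idx_list), batch_size)]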
Example #2
    def fit_full_train(self, train_g_data_list, arch_encoding, epoch):
        meters = MetricLogger(delimiter=" ")
        self.predictor.train()
        idx_list = list(range(len(train_g_data_list)))
        idx_dataset = ProductList(idx_list)
        training_data = torch.utils.data.DataLoader(idx_dataset, batch_size=self.batch_size,
                                                    shuffle=True)
        counter = 0
        for v in training_data:
            pair1_idx = v[0].cpu().numpy()
            pair2_idx = v[1].cpu().numpy()

            counter += len(pair1_idx)
            data_list_pair1 = []
            arch_path_encoding_pair1 = []

            data_list_pair2 = []
            arch_path_encoding_pair2 = []

            for idx1, idx2 in zip(pair1_idx, pair2_idx):
                g_d_1 = train_g_data_list[idx1]
                data_list_pair1.append(g_d_1)
                arch_path_encoding_pair1.append(arch_encoding[idx1])

                g_d_2 = train_g_data_list[idx2]
                data_list_pair2.append(g_d_2)
                arch_path_encoding_pair2.append(arch_encoding[idx2])

            dist_gt = torch.tensor([edit_distance_normalization(arch_path_encoding_pair1[i],
                                                                arch_path_encoding_pair2[i], self.node_num)
                                    for i in range(len(arch_path_encoding_pair1))], dtype=torch.float32)
            batch1 = Batch.from_data_list(data_list_pair1)
            batch1 = batch1.to(self.device)

            batch2 = Batch.from_data_list(data_list_pair2)
            batch2 = batch2.to(self.device)

            dist_gt = dist_gt.to(self.device)

            batch_nodes_1, batch_edge_idx_1, batch_idx_1 = batch1.x, batch1.edge_index, batch1.batch
            batch_nodes_2, batch_edge_idx_2, batch_idx_2 = batch2.x, batch2.edge_index, batch2.batch
            prediction = self.predictor(batch_nodes_1, batch_edge_idx_1, batch_idx_1, batch_nodes_2,
                                        batch_edge_idx_2, batch_idx_2)
            prediction = prediction.squeeze(dim=-1)
            loss = self.criterion(prediction, dist_gt)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            meters.update(loss=loss.item())
        save_path = os.path.join(self.save_dir, f'unsupervised_ss_rl_epoch_{epoch}.pt')
        if self.save_model:
            torch.save(self.predictor.state_dict(), save_path)
        if self.logger:
            self.logger.info(str(meters))
        else:
            print(meters)
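
Example #2 wraps the index list in a ProductList dataset that is not defined in this listing. Given that each DataLoader batch is unpacked as v[0]/v[1] index pairs, one plausible reading is a Dataset over the Cartesian product of the index list with itself; a sketch under that assumption:

import itertools
import torch.utils.data

# Assumed shape of ProductList: every item is an (idx1, idx2) pair, so the
# DataLoader's default collate yields v with v[0] and v[1] as index tensors.
class ProductList(torch.utils.data.Dataset):
    def __init__(self, idx_list):
        self.pairs = list(itertools.product(idx_list, idx_list))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, i):
        return self.pairs[i]

Note that materializing the full product is O(n^2) in memory, so this reading only makes sense for modest index lists.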
Example #3
    def fit_train(self, total_archs, epoch):
        meters = MetricLogger(delimiter=" ")
        self.predictor.train()
        idx_list = list(range(len(total_archs)))
        idx_list2 = list(idx_list)
        random.shuffle(idx_list)
        batch_idx_list_1 = gen_batch_idx_gen(idx_list, self.batch_size, drop_last=True)
        random.shuffle(idx_list2)
        batch_idx_list_2 = gen_batch_idx_gen(idx_list2, self.batch_size, drop_last=True)
        counter = 0
        for pair1_idx, pair2_idx in zip(batch_idx_list_1, batch_idx_list_2):
            counter += len(pair1_idx)
            data_list_pair1 = []
            arch_path_encoding_pair1 = []

            data_list_pair2 = []
            arch_path_encoding_pair2 = []

            for idx1, idx2 in zip(pair1_idx, pair2_idx):
                arch_1_info = gen_arch_info(total_archs[idx1])
                g_d_1 = arch_1_info['g_data']
                data_list_pair1.append(g_d_1)
                arch_path_encoding_pair1.append(arch_1_info['pe_path_enc_aware_vec'])

                arch_2_info = gen_arch_info(total_archs[idx2])
                data_list_pair2.append(arch_2_info['g_data'])
                arch_path_encoding_pair2.append(arch_2_info['pe_path_enc_aware_vec'])

            dist_gt = torch.tensor([edit_distance_normalization(arch_path_encoding_pair1[i],
                                                                arch_path_encoding_pair2[i], self.node_num)
                                    for i in range(len(arch_path_encoding_pair1))], dtype=torch.float32)
            batch1 = Batch.from_data_list(data_list_pair1)
            batch1 = batch1.to(self.device)

            batch2 = Batch.from_data_list(data_list_pair2)
            batch2 = batch2.to(self.device)

            dist_gt = dist_gt.to(self.device)

            batch_nodes_1, batch_edge_idx_1, batch_idx_1 = batch1.x, batch1.edge_index, batch1.batch
            batch_nodes_2, batch_edge_idx_2, batch_idx_2 = batch2.x, batch2.edge_index, batch2.batch
            prediction = self.predictor(batch_nodes_1, batch_edge_idx_1, batch_idx_1, batch_nodes_2,
                                        batch_edge_idx_2, batch_idx_2)
            prediction = prediction.squeeze(dim=-1)
            loss = self.criterion(prediction, dist_gt)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            meters.update(loss=loss.item())
        save_path = os.path.join(self.save_dir, f'unsupervised_ss_rl_epoch_{epoch}.pt')
        if self.save_model:
            torch.save(self.predictor.state_dict(), save_path)
        if self.logger:
            self.logger.info(str(meters))
        else:
            print(meters)
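
Example #3 batches through a gen_batch_idx_gen variant with a drop_last flag. The real helper is not shown; a minimal generator sketch consistent with those calls would be the one below. (gen_arch_info is likewise external; the calls above only reveal that it returns a dict with at least 'g_data' and 'pe_path_enc_aware_vec' keys.)

# Hypothetical generator variant of gen_batch_idx; drop_last mirrors the
# DataLoader option and discards a final incomplete batch.
def gen_batch_idx_gen(idx_list, batch_size, drop_last=False):
    for i in range(0, len(idx_list), batch_size):
        chunk = idx_list[i:i + batch_size]
        if drop_last and len(chunk) < batch_size:
            return
        yield chunk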
Example #4
    def fit(self, edge_index, node_feature, edge_index_reverse,
            node_feature_reverse, val_accuracy):
        meters = MetricLogger(delimiter="  ")
        self.stage1.train()
        for epoch in range(self.epochs):
            idx_list = list(range(len(edge_index)))
            random.shuffle(idx_list)
            batch_idx_list = gen_batch_idx(idx_list, 10)
            for i, batch_idx in enumerate(batch_idx_list):
                data_list = []
                data_list_reverse = []
                target_list = []
                for idx in batch_idx:
                    g_d = Data(edge_index=edge_index[idx].long(),
                               x=node_feature[idx].float())
                    data_list.append(g_d)
                    g_d_reverse = Data(
                        edge_index=edge_index_reverse[idx].long(),
                        x=node_feature_reverse[idx].float())
                    data_list_reverse.append(g_d_reverse)
                    target_list.append(val_accuracy[idx])

                val_tensor = torch.tensor(target_list, dtype=torch.float32)
                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                val_tensor = val_tensor.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
                batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)

                batch_reverse = Batch.from_data_list(data_list_reverse)
                batch_reverse = batch_reverse.to(self.device)
                batch_nodes_reverse = batch_reverse.x
                batch_edge_idx_reverse = batch_reverse.edge_index
                batch_idx_reverse = batch_reverse.batch
                batch_nodes_reverse = F.normalize(batch_nodes_reverse,
                                                  p=2,
                                                  dim=-1)

                pred = self.stage1(batch_nodes, batch_edge_idx, batch_idx_g,
                                   batch_nodes_reverse, batch_edge_idx_reverse)
                pred = pred.squeeze()
                loss = self.criterion(pred, val_tensor)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step(epoch + int(i / 30))
                meters.update(loss=loss.item())
        return meters.meters['loss'].avg

    def pred(self, edge_index, node_feature):
        pred_list = []
        mean_list = []
        std_list = []
        idx_list = list(range(len(edge_index)))
        self.nas_agent.eval()
        batch_idx_list = gen_batch_idx(idx_list, 64)
        with torch.no_grad():
            for batch_idx in batch_idx_list:
                data_list = []
                for idx in batch_idx:
                    g_d = Data(edge_index=edge_index[idx].long(), x=node_feature[idx].float())
                    data_list.append(g_d)
                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
                pred, mean, std = self.nas_agent(batch_nodes, batch_edge_idx, batch_idx_g)

                pred = pred.squeeze()
                mean = mean.squeeze()
                std = std.squeeze()
                if pred.dim() == 0:
                    pred.unsqueeze_(0)
                    mean.unsqueeze_(0)
                    std.unsqueeze_(0)
                pred_list.append(pred)
                mean_list.append(mean)
                std_list.append(std)
        return torch.cat(pred_list, dim=0), torch.cat(mean_list, dim=0), torch.cat(std_list, dim=0)

    def fit(self, edge_index, node_feature, val_accuracy, logger=None):
        meters = MetricLogger(delimiter=" ")
        self.nas_agent.train()
        for epoch in range(self.epoch):
            idx_list = list(range(len(edge_index)))
            random.shuffle(idx_list)
            batch_idx_list = gen_batch_idx(idx_list, self.batch_size)
            counter = 0
            for batch_idx in batch_idx_list:
                counter += len(batch_idx)
                data_list = []
                target_list = []
                for idx in batch_idx:
                    g_d = Data(edge_index=edge_index[idx].long(), x=node_feature[idx].float())
                    data_list.append(g_d)
                    target_list.append(val_accuracy[idx])
                val_tensor = torch.tensor(target_list, dtype=torch.float32)
                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                val_tensor = val_tensor.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
                # batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)

                pred, mean, std = self.nas_agent(batch_nodes, batch_edge_idx, batch_idx_g)
                val_tensor = val_tensor.unsqueeze(dim=1)
                loss = self.criterion(mean, std, val_tensor)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                meters.update(loss=loss.item())
        if logger:
            logger.info(str(meters))
        return meters.meters['loss'].avg
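
The second fit above calls self.criterion(mean, std, val_tensor) with three arguments, which suggests a probabilistic head trained with a Gaussian negative log-likelihood. The actual criterion is not shown; a sketch under that assumption follows (torch.nn.GaussianNLLLoss offers a built-in equivalent with an (input, target, var) signature).

import torch

# Assumed form of the three-argument criterion: negative log-likelihood of
# the target under a Normal(mean, std**2) predictive distribution.
def gaussian_nll(mean, std, target, eps=1e-6):
    var = std.clamp_min(eps) ** 2
    return 0.5 * (torch.log(var) + (target - mean) ** 2 / var).mean()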
Example #7
    def fit_train_g_data(self, g_data, accuracy, logger=None):
        meters = MetricLogger(delimiter=" ")
        self.predictor.train()
        for epoch in range(self.epoch):
            idx_list = list(range(len(g_data)))
            random.shuffle(idx_list)
            batch_idx_list = gen_batch_idx(idx_list, self.batch_size)
            counter = 0
            for i, batch_idx in enumerate(batch_idx_list):
                counter += len(batch_idx)
                data_list = [g_data[idx] for idx in batch_idx]
                target_list = [accuracy[idx] for idx in batch_idx]
                val_tensor = torch.tensor(target_list, dtype=torch.float32)
                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                val_tensor = val_tensor.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
                # batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)

                pred = self.predictor(batch_nodes, batch_edge_idx, batch_idx_g)
                pred = pred.squeeze()
                loss = self.criterion(pred, val_tensor)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                meters.update(loss=loss.item())
            save_path = os.path.join(self.save_dir, f'supervised_gin_epoch_{epoch}.pt')
            if self.save_model:
                torch.save(self.predictor.state_dict(), save_path)
        return meters.meters['loss'].avg
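
Several trainers in this listing log through a MetricLogger with named meters, a .delimiter attribute, and str() formatting, matching the utility popularized by maskrcnn-benchmark and the torchvision reference scripts. The class itself is not shown; a compact stand-in covering only the pieces these examples rely on might be:

from collections import defaultdict

# Minimal stand-in for the MetricLogger assumed by these examples:
# supports update(loss=...), meters['loss'].avg, and str(meters).
class SmoothedValue:
    def __init__(self):
        self.total, self.count = 0.0, 0

    def update(self, value):
        self.total += value
        self.count += 1

    @property
    def avg(self):
        return self.total / max(self.count, 1)

class MetricLogger:
    def __init__(self, delimiter=' '):
        self.delimiter = delimiter
        self.meters = defaultdict(SmoothedValue)

    def update(self, **kwargs):
        for name, value in kwargs.items():
            self.meters[name].update(value)

    def __str__(self):
        return self.delimiter.join(
            f'{name}: {meter.avg:.4f}' for name, meter in self.meters.items())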
Example #8
    def inference(self, edge_index, node_feature, arch_encoding, batch_idx_test_1, batch_idx_test_2):
        self.predictor.eval()
        error_list = []
        precision_list = []
        for i, pair1_idx in enumerate(batch_idx_test_1):
            pair2_idx = batch_idx_test_2[i]
            data_list_pair1 = []
            arch_path_encoding_pair1 = []

            data_list_pair2 = []
            arch_path_encoding_pair2 = []

            for idx1, idx2 in zip(pair1_idx, pair2_idx):
                g_d_1 = Data(edge_index=edge_index[idx1].long(), x=node_feature[idx1].float())
                data_list_pair1.append(g_d_1)
                arch_path_encoding_pair1.append(arch_encoding[idx1])

                g_d_2 = Data(edge_index=edge_index[idx2].long(), x=node_feature[idx2].float())
                data_list_pair2.append(g_d_2)
                arch_path_encoding_pair2.append(arch_encoding[idx2])

            dist_gt = torch.tensor([edit_distance(arch_path_encoding_pair1[j], arch_path_encoding_pair2[j])
                                    for j in range(len(arch_path_encoding_pair1))], dtype=torch.float32)
            batch1 = Batch.from_data_list(data_list_pair1)
            batch1 = batch1.to(self.device)

            batch2 = Batch.from_data_list(data_list_pair2)
            batch2 = batch2.to(self.device)

            dist_gt = dist_gt.to(self.device)

            batch_nodes_1, batch_edge_idx_1, batch_idx_1 = batch1.x, batch1.edge_index, batch1.batch
            batch_nodes_2, batch_edge_idx_2, batch_idx_2 = batch2.x, batch2.edge_index, batch2.batch
            prediction = self.predictor(batch_nodes_1, batch_edge_idx_1, batch_idx_1, batch_nodes_2,
                                        batch_edge_idx_2, batch_idx_2)
            # map the model output back to an edit-distance scale
            prediction = -torch.log(prediction.squeeze(dim=-1)) * self.node_num

            errors = torch.abs(dist_gt - prediction)
            precision = (torch.sum(errors < 1) * 1.) / errors.size(0)
            error_list.append(torch.mean(errors).item())
            precision_list.append(precision.item())
        if self.logger:
            self.logger.info(f'Error is {np.mean(error_list)}, Precision is {np.mean(precision_list)}')
        else:
            print(f'Error is {np.mean(error_list)}, Precision is {np.mean(precision_list)}')
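
edit_distance_normalization is not defined in this listing, but the inference code above recovers a raw edit distance from the model output with -log(p) * node_num. That implies the training target maps a distance d to exp(-d / node_num). A sketch consistent with that inverse, reusing the edit_distance helper seen above:

import math

# Inferred from the -log(p) * node_num inverse used in Example #8: squash a
# raw graph edit distance d into (0, 1] via exp(-d / node_num).
def edit_distance_normalization(enc1, enc2, node_num):
    d = edit_distance(enc1, enc2)  # helper from this codebase, not shown
    return math.exp(-d / node_num)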
Example #9
    def pred_g_data(self, g_data):
        pred_list = []
        idx_list = list(range(len(g_data)))
        self.predictor.eval()
        batch_idx_list = gen_batch_idx(idx_list, 64)
        with torch.no_grad():
            for batch_idx in batch_idx_list:
                data_list = [g_data[idx] for idx in batch_idx]
                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
                pred = self.predictor(batch_nodes, batch_edge_idx, batch_idx_g).squeeze()
                if pred.dim() == 0:
                    pred.unsqueeze_(0)
                pred_list.append(pred)
        return torch.cat(pred_list, dim=0)
Example #10
    def fit(self, edge_index, node_feature, edge_index_reverse,
            node_feature_reverse, val_accuracy, val_accuracy_cls):
        meters_cls = MetricLogger(delimiter="  ")
        meters_regress = MetricLogger(delimiter="  ")
        self.stage1.train()
        self.stage2.train()
        for epoch in range(self.epochs):
            idx_list = list(range(len(edge_index)))
            random.shuffle(idx_list)
            batch_idx_list = gen_batch_idx(idx_list, 10)

            for i, batch_idx in enumerate(batch_idx_list):
                data_list = []
                data_list_reverse = []
                target_list_cls = []
                for idx in batch_idx:
                    g_d = Data(edge_index=edge_index[idx].long(),
                               x=node_feature[idx].float())
                    data_list.append(g_d)
                    g_d_reverse = Data(
                        edge_index=edge_index_reverse[idx].long(),
                        x=node_feature_reverse[idx].float())
                    data_list_reverse.append(g_d_reverse)
                    target_list_cls.append(val_accuracy_cls[idx])

                val_cls_tensor = torch.tensor(target_list_cls,
                                              dtype=torch.long)
                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                val_cls_tensor = val_cls_tensor.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
                batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
                batch_reverse = Batch.from_data_list(data_list_reverse)
                batch_reverse = batch_reverse.to(self.device)
                batch_nodes_reverse = batch_reverse.x
                batch_edge_idx_reverse = batch_reverse.edge_index
                batch_idx_reverse = batch_reverse.batch
                batch_nodes_reverse = F.normalize(batch_nodes_reverse,
                                                  p=2,
                                                  dim=-1)

                pred_cls = self.stage1(batch_nodes, batch_edge_idx,
                                       batch_idx_g, batch_nodes_reverse,
                                       batch_edge_idx_reverse).squeeze()
                loss_stage1 = self.criterion_ce(pred_cls, val_cls_tensor)
                self.optimizer_cls.zero_grad()
                loss_stage1.backward()
                self.optimizer_cls.step()
                self.scheduler_cls.step(epoch + int(i / 30))
                meters_cls.update(loss=loss_stage1.item())

                if pred_cls.dim() == 1:
                    pred_cls.unsqueeze_(dim=0)

                pred_max = torch.argmax(pred_cls, dim=1)
                if torch.sum(pred_max) == 0:
                    continue
                data_list = []
                data_list_reverse = []
                target_list = []
                for k, idx in enumerate(batch_idx):
                    if pred_max[k] == 0:
                        continue
                    g_d = Data(edge_index=edge_index[idx].long(),
                               x=node_feature[idx].float())
                    data_list.append(g_d)
                    g_d_reverse = Data(
                        edge_index=edge_index_reverse[idx].long(),
                        x=node_feature_reverse[idx].float())
                    data_list_reverse.append(g_d_reverse)
                    target_list.append(val_accuracy[idx])
                val_tensor = torch.tensor(target_list, dtype=torch.float32)
                val_tensor = val_tensor.to(self.device)
                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
                batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
                batch_reverse = Batch.from_data_list(data_list_reverse)
                batch_reverse = batch_reverse.to(self.device)
                batch_nodes_reverse = batch_reverse.x
                batch_edge_idx_reverse = batch_reverse.edge_index
                batch_idx_reverse = batch_reverse.batch
                batch_nodes_reverse = F.normalize(batch_nodes_reverse,
                                                  p=2,
                                                  dim=-1)
                pred_regress = self.stage2(batch_nodes, batch_edge_idx,
                                           batch_idx_g, batch_nodes_reverse,
                                           batch_edge_idx_reverse).squeeze()
                if pred_regress.dim() == 0:
                    loss_stage2 = self.criterion(pred_regress, val_tensor[0])
                else:
                    loss_stage2 = self.criterion(pred_regress, val_tensor)
                self.optimizer_regress.zero_grad()
                loss_stage2.backward()
                self.optimizer_regress.step()
                self.scheduler_regress.step(epoch + int(i / 30))
                meters_regress.update(loss=loss_stage2.item())
        return meters_cls.meters['loss'].avg, meters_regress.meters['loss'].avg
Example #11
    def pred(self, edge_index, node_feature, edge_index_reverse,
             node_feature_reverse):
        pred_list = []
        idx_list = list(range(len(edge_index)))
        self.stage1.eval()
        self.stage2.eval()
        batch_idx_list = gen_batch_idx(idx_list, 32)
        with torch.no_grad():
            for i, batch_idx in enumerate(batch_idx_list):
                data_list = []
                data_list_reverse = []
                for idx in batch_idx:
                    g_d = Data(edge_index=edge_index[idx].long(),
                               x=node_feature[idx].float())
                    data_list.append(g_d)
                    g_d_reverse = Data(
                        edge_index=edge_index_reverse[idx].long(),
                        x=node_feature_reverse[idx].float())
                    data_list_reverse.append(g_d_reverse)
                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
                batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
                batch_reverse = Batch.from_data_list(data_list_reverse)
                batch_reverse = batch_reverse.to(self.device)
                batch_nodes_reverse = batch_reverse.x
                batch_edge_idx_reverse = batch_reverse.edge_index
                batch_idx_reverse = batch_reverse.batch
                batch_nodes_reverse = F.normalize(batch_nodes_reverse,
                                                  p=2,
                                                  dim=-1)
                pred_cls = self.stage1(batch_nodes, batch_edge_idx,
                                       batch_idx_g, batch_nodes_reverse,
                                       batch_edge_idx_reverse).squeeze()
                if pred_cls.dim() == 1:
                    pred_cls.unsqueeze_(dim=0)
                pred_max = torch.argmax(pred_cls, dim=1)
                if pred_max.dim() == 0:
                    pred_max.unsqueeze_(dim=0)
                if torch.sum(pred_max) == 0:
                    # stage 1 rejected the whole batch: emit zeros so the
                    # output stays aligned with the input order
                    pred_list.append(torch.zeros_like(pred_max, dtype=torch.float32))
                    continue
                data_list = []
                data_list_reverse = []
                for k, idx in enumerate(batch_idx):
                    if pred_max[k] == 0:
                        continue
                    g_d = Data(edge_index=edge_index[idx].long(),
                               x=node_feature[idx].float())
                    data_list.append(g_d)
                    g_d_reverse = Data(
                        edge_index=edge_index_reverse[idx].long(),
                        x=node_feature_reverse[idx].float())
                    data_list_reverse.append(g_d_reverse)

                batch = Batch.from_data_list(data_list)
                batch = batch.to(self.device)
                batch_nodes, batch_edge_idx, batch_idx_g = batch.x, batch.edge_index, batch.batch
                batch_nodes = F.normalize(batch_nodes, p=2, dim=-1)
                batch_reverse = Batch.from_data_list(data_list_reverse)
                batch_reverse = batch_reverse.to(self.device)
                batch_nodes_reverse = batch_reverse.x
                batch_edge_idx_reverse = batch_reverse.edge_index
                batch_idx_reverse = batch_reverse.batch
                batch_nodes_reverse = F.normalize(batch_nodes_reverse,
                                                  p=2,
                                                  dim=-1)
                pred_regress = self.stage2(batch_nodes, batch_edge_idx,
                                           batch_idx_g, batch_nodes_reverse,
                                           batch_edge_idx_reverse).squeeze()
                pred = torch.zeros_like(pred_max, dtype=torch.float32)
                if pred_regress.dim() == 0:
                    pred_regress.unsqueeze_(dim=0)
                # scatter the regression outputs back to the positions that
                # stage 1 classified as non-zero; the rest stay at zero
                counter = 0
                for j in range(pred.size(0)):
                    if pred_max[j] == 0:
                        pred[j] = 0
                    else:
                        pred[j] = pred_regress[counter]
                        counter += 1
                if pred.dim() == 0:
                    pred.unsqueeze_(0)
                pred_list.append(pred)
        return torch.cat(pred_list, dim=0)
Example #12
    def fit_train(self, train_g_data_list, arch_encoding, epoch):
        meters = MetricLogger(delimiter=" ")
        self.predictor.train()
        idx_list = list(range(len(train_g_data_list)))
        idx_list2 = list(idx_list)
        random.shuffle(idx_list)
        random.shuffle(idx_list2)
        if self.args and self.args.add_corresponding:
            # also pair every architecture with itself (edit distance 0)
            idx = list(range(len(train_g_data_list)))
            idx2 = list(range(len(train_g_data_list)))
            idx_list.extend(idx)
            idx_list2.extend(idx2)
            # reshuffle both lists with one shared permutation so the
            # appended pairs stay aligned
            indices_list = list(range(len(idx_list)))
            random.shuffle(indices_list)
            idx_list = [idx_list[i] for i in indices_list]
            idx_list2 = [idx_list2[i] for i in indices_list]
        batch_idx_list_1 = gen_batch_idx_gen(idx_list, self.batch_size, drop_last=True)
        batch_idx_list_2 = gen_batch_idx_gen(idx_list2, self.batch_size, drop_last=True)
        counter = 0
        for pair1_idx, pair2_idx in zip(batch_idx_list_1, batch_idx_list_2):
            counter += len(pair1_idx)
            data_list_pair1 = []
            arch_path_encoding_pair1 = []

            data_list_pair2 = []
            arch_path_encoding_pair2 = []

            for idx1, idx2 in zip(pair1_idx, pair2_idx):
                g_d_1 = train_g_data_list[idx1]
                data_list_pair1.append(g_d_1)
                arch_path_encoding_pair1.append(arch_encoding[idx1])

                g_d_2 = train_g_data_list[idx2]
                data_list_pair2.append(g_d_2)
                arch_path_encoding_pair2.append(arch_encoding[idx2])

            if self.args.ged_type == 'normalized':
                dist_gt = torch.tensor([edit_distance_normalization(arch_path_encoding_pair1[i],
                                                                    arch_path_encoding_pair2[i], self.node_num)
                                        for i in range(len(arch_path_encoding_pair1))], dtype=torch.float32)
            elif self.args.ged_type == 'wo_normalized':
                dist_gt = torch.tensor([edit_distance(arch_path_encoding_pair1[i],
                                                      arch_path_encoding_pair2[i])
                                        for i in range(len(arch_path_encoding_pair1))], dtype=torch.float32)
            else:
                raise ValueError(f'The ged type {self.args.ged_type} is not supported!')
            batch1 = Batch.from_data_list(data_list_pair1)
            batch1 = batch1.to(self.device)

            batch2 = Batch.from_data_list(data_list_pair2)
            batch2 = batch2.to(self.device)

            dist_gt = dist_gt.to(self.device)

            batch_nodes_1, batch_edge_idx_1, batch_idx_1 = batch1.x, batch1.edge_index, batch1.batch
            batch_nodes_2, batch_edge_idx_2, batch_idx_2 = batch2.x, batch2.edge_index, batch2.batch
            prediction = self.predictor(batch_nodes_1, batch_edge_idx_1, batch_idx_1, batch_nodes_2,
                                        batch_edge_idx_2, batch_idx_2)
            prediction = prediction.squeeze(dim=-1)
            loss = self.criterion(prediction, dist_gt)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            meters.update(loss=loss.item())
        save_path = os.path.join(self.save_dir, f'unsupervised_ss_rl_epoch_{epoch}.pt')
        if self.save_model:
            torch.save(self.predictor.state_dict(), save_path)
        if self.logger:
            self.logger.info(str(meters))
        else:
            print(meters)
Example #13
def train_nested(train_loader, model, criterion, optimizer, epoch, args,
                 logger):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch),
                             logger=logger)

    # switch to train mode
    model.train()
    device = torch.device(f'cuda:{args.gpu}')
    end = time.time()
    center_list = []
    step = args.batch_step
    for i, (g_d, path_encodings) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        batch = Batch.from_data_list(g_d)
        batch = batch.to(device)

        indices = list(range(len(g_d)))
        random.shuffle(indices)
        indices = indices[:args.train_samples]
        indices_list = [
            indices[i * step:(i + 1) * step]
            for i in range(args.train_samples // step)
        ]
        for sample_ids in indices_list:
            # compute output
            logits, label, centers = model(batch=batch,
                                           path_encoding=path_encodings,
                                           device=device,
                                           search_space=args.search_space,
                                           sample_ids=sample_ids,
                                           logger=logger)
            loss = criterion(logits, label)
            if args.center_regularization:
                # penalize pairwise similarity between distinct centers,
                # masking out the diagonal (self-similarity) entries
                center_dist = torch.mm(centers, centers.T)
                masks = torch.ones_like(center_dist)
                diag_idx = list(range(center_dist.size(0)))
                masks[diag_idx, diag_idx] = 0
                center_loss = 0.5 * torch.mean(masks * center_dist)
                loss = loss + 0.5 * center_loss
            size_logits = logits.size(0)
            # acc1/acc5 are (K+1)-way contrast classifier accuracy
            # measure accuracy and record loss
            acc1, acc5 = accuracy(logits, label, topk=(1, 5))
            losses.update(loss.item(), size_logits)
            top1.update(acc1[0], size_logits)
            top5.update(acc5[0], size_logits)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
        center_list.append(centers.detach().cpu().numpy())
    return center_list
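
The accuracy helper used above is not part of this listing. Given how acc1[0] and acc5[0] are consumed, it presumably follows the standard top-k routine from the PyTorch ImageNet reference script; a sketch under that assumption:

import torch

# Standard top-k accuracy (as in the PyTorch ImageNet reference script);
# returns one percentage tensor per requested k.
def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res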