Example #1
 def get_cross_validation_loader(self, fold):
     (train_idx, test_idx, val_idx
      ) = self.train_ids[fold], self.test_ids[fold], self.val_ids[fold]
     train_dataset = self.dataset[train_idx]
     test_dataset = self.dataset[test_idx]
     val_dataset = self.dataset[val_idx]
     if 'adj' in train_dataset[0]:
         train_loader = DenseDataLoader(train_dataset,
                                        self.batch_size,
                                        shuffle=True)
         val_loader = DenseDataLoader(val_dataset,
                                      self.batch_size,
                                      shuffle=False)
         test_loader = DenseDataLoader(test_dataset,
                                       self.batch_size,
                                       shuffle=False)
     else:
         train_loader = DataLoader(train_dataset,
                                   self.batch_size,
                                   shuffle=True)
         val_loader = DataLoader(val_dataset,
                                 self.batch_size,
                                 shuffle=False)
         test_loader = DataLoader(test_dataset,
                                  self.batch_size,
                                  shuffle=False)
     return train_loader, val_loader, test_loader
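Note: the branch on 'adj' above separates densified graphs from sparse ones. DenseDataLoader stacks fixed-size tensors such as data.adj, while the regular DataLoader builds sparse mini-batches from edge_index. A minimal sketch of how a dataset ends up in the dense branch, assuming the ENZYMES TUDataset and an illustrative max_nodes value:

import torch_geometric.transforms as T
from torch_geometric.data import DenseDataLoader
from torch_geometric.datasets import TUDataset

max_nodes = 126  # illustrative upper bound on nodes per graph
dense_dataset = TUDataset('/tmp/ENZYMES', name='ENZYMES',
                          transform=T.ToDense(max_nodes))
assert 'adj' in dense_dataset[0]  # ToDense adds a dense, padded adjacency matrix
loader = DenseDataLoader(dense_dataset, batch_size=32, shuffle=True)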
Example #2
def get_dataloaders(parser_args):
    """Creates the datasets and the corresponding dataloaders

    Args:
        parser_args (dict): parsed arguments

    Returns:
        (:obj:`torch.utils.data.dataloader`, :obj:`torch.utils.data.dataloader`): train, validation dataloaders
    """

    path_to_data = parser_args.path_to_data
    partition = parser_args.partition
    seed = parser_args.seed
    means_path = parser_args.means_path
    stds_path = parser_args.stds_path

    data = ARTCDataset(path_to_data)
    train_indices, temp = train_test_split(data.idxs,
                                           train_size=partition[0],
                                           random_state=seed)
    val_indices, _ = train_test_split(temp,
                                      test_size=partition[2] /
                                      (partition[1] + partition[2]),
                                      random_state=seed)

    if (means_path is None) or (stds_path is None):
        train_set_stats = ARTCDataset(path_to_data, indices=train_indices)
        means, stds = stats_extractor(train_set_stats)
        np.save("./means.npy", means)
        np.save("./stds.npy", stds)
    else:
        try:
            means = np.load(means_path)
            stds = np.load(stds_path)
        except (OSError, ValueError) as err:
            # fail loudly: without valid statistics the Normalize transform
            # below cannot be built, and means/stds would be undefined
            raise RuntimeError(
                "Could not load means/stds; check means_path and stds_path."
            ) from err

    train_set = ARTCDataset(path_to_data,
                            indices=train_indices,
                            transform=Normalize(means, stds))
    validation_set = ARTCDataset(path_to_data,
                                 indices=val_indices,
                                 transform=Normalize(means, stds))

    dataloader_train = DenseDataLoader(train_set,
                                       batch_size=parser_args.batch_size,
                                       shuffle=True,
                                       num_workers=1)
    dataloader_validation = DenseDataLoader(validation_set,
                                            batch_size=parser_args.batch_size,
                                            shuffle=False,
                                            num_workers=1)
    return dataloader_train, dataloader_validation
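A hypothetical invocation of get_dataloaders; the path, split ratios, and batch size below are placeholders rather than values from the original project:

from argparse import Namespace

parser_args = Namespace(path_to_data='data/artc',    # placeholder path
                        partition=[0.7, 0.15, 0.15],  # train/val/test ratios
                        seed=42,
                        means_path=None,              # None -> recompute stats
                        stds_path=None,
                        batch_size=16)
dataloader_train, dataloader_validation = get_dataloaders(parser_args)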
Example #3
def main():
    opt = OptInit().get_args()
    logging.info('===> Creating dataloader ...')
    train_dataset = GeoData.S3DIS(opt.data_dir,
                                  opt.area,
                                  True,
                                  pre_transform=T.NormalizeScale())
    train_loader = DenseDataLoader(train_dataset,
                                   batch_size=opt.batch_size,
                                   shuffle=True,
                                   num_workers=4)
    test_dataset = GeoData.S3DIS(opt.data_dir,
                                 opt.area,
                                 train=False,
                                 pre_transform=T.NormalizeScale())
    test_loader = DenseDataLoader(test_dataset,
                                  batch_size=opt.batch_size,
                                  shuffle=False,
                                  num_workers=0)
    opt.n_classes = train_loader.dataset.num_classes

    logging.info('===> Loading the network ...')
    model = DenseDeepGCN(opt).to(opt.device)
    if opt.multi_gpus:
        model = DataParallel(DenseDeepGCN(opt)).to(opt.device)
    logging.info('===> loading pre-trained ...')
    model, opt.best_value, opt.epoch = load_pretrained_models(
        model, opt.pretrained_model, opt.phase)
    logging.info(model)

    logging.info('===> Init the optimizer ...')
    criterion = torch.nn.CrossEntropyLoss().to(opt.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_adjust_freq,
                                                opt.lr_decay_rate)
    optimizer, scheduler, opt.lr = load_pretrained_optimizer(
        opt.pretrained_model, optimizer, scheduler, opt.lr)

    logging.info('===> Init Metric ...')
    opt.losses = AverageMeter()
    opt.test_value = 0.

    logging.info('===> start training ...')
    for _ in range(opt.epoch, opt.total_epochs):
        opt.epoch += 1
        logging.info('Epoch:{}'.format(opt.epoch))
        train(model, train_loader, optimizer, scheduler, criterion, opt)
        if opt.epoch % opt.eval_freq == 0 and opt.eval_freq != -1:
            test(model, test_loader, opt)
        scheduler.step()
    logging.info('Saving the final model. Finished!')
Example #4
    def get_loader(self):  # parameters are taken from self.config
        dataset = self.get_dataset(self.config.dataset,
                                   sparse=self.config.sparse,
                                   dataset_div=self.config.dataset_div)
        n = (len(dataset) + 9) // 10
        test_dataset = dataset[:n]
        val_dataset = dataset[n:2 * n]
        train_dataset = dataset[2 * n:]

        train_loader = DenseDataLoader(train_dataset,
                                       batch_size=self.config.batch_size)
        val_loader = DenseDataLoader(val_dataset,
                                     batch_size=self.config.batch_size)
        test_loader = DenseDataLoader(test_dataset,
                                      batch_size=self.config.batch_size)

        return train_loader, val_loader, test_loader
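For intuition: n = (len(dataset) + 9) // 10 rounds up to the nearest tenth of the dataset, so with 600 graphs n = (600 + 9) // 10 = 60, giving dataset[:60] as the test split, dataset[60:120] as the validation split, and the remaining 480 graphs for training.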
Example #5
def test_enzymes():
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dataset = TUDataset(root, 'ENZYMES')

    assert len(dataset) == 600
    assert dataset.num_features == 3
    assert dataset.num_classes == 6
    assert dataset.__repr__() == 'ENZYMES(600)'

    assert len(dataset[0]) == 3
    assert len(dataset.shuffle()) == 600
    assert len(dataset.shuffle(return_perm=True)) == 2
    assert len(dataset[:100]) == 100
    assert len(dataset[torch.arange(100, dtype=torch.long)]) == 100
    mask = torch.zeros(600, dtype=torch.bool)
    mask[:100] = 1
    assert len(dataset[mask]) == 100

    loader = DataLoader(dataset, batch_size=len(dataset))
    for data in loader:
        assert data.num_graphs == 600

        avg_num_nodes = data.num_nodes / data.num_graphs
        assert pytest.approx(avg_num_nodes, abs=1e-2) == 32.63

        avg_num_edges = data.num_edges / (2 * data.num_graphs)
        assert pytest.approx(avg_num_edges, abs=1e-2) == 62.14

        assert len(data) == 5
        assert list(data.x.size()) == [data.num_nodes, 3]
        assert list(data.y.size()) == [data.num_graphs]
        assert data.y.max() + 1 == 6
        assert list(data.batch.size()) == [data.num_nodes]
        assert data.ptr.numel() == data.num_graphs + 1

        assert data.contains_isolated_nodes()
        assert not data.contains_self_loops()
        assert data.is_undirected()

    loader = DataListLoader(dataset, batch_size=len(dataset))
    for data_list in loader:
        assert len(data_list) == 600

    dataset.transform = ToDense(num_nodes=126)
    loader = DenseDataLoader(dataset, batch_size=len(dataset))
    for data in loader:
        assert len(data) == 4
        assert list(data.x.size()) == [600, 126, 3]
        assert list(data.adj.size()) == [600, 126, 126]
        assert list(data.mask.size()) == [600, 126]
        assert list(data.y.size()) == [600, 1]

    dataset = TUDataset(root, 'ENZYMES', use_node_attr=True)
    assert dataset.num_node_features == 21
    assert dataset.num_features == 21
    assert dataset.num_edge_features == 0

    shutil.rmtree(root)
Example #6
    def k_fold_loader_generator(self, folds):
        dataset = self.get_dataset(self.config.dataset,
                                   sparse=self.config.sparse,
                                   dataset_div=self.config.dataset_div)
        train_indices, val_indices, test_indices = self.get_k_fold_indices(
            folds, len(dataset))

        for fold, (train_idx, val_idx, test_idx) in enumerate(
                zip(train_indices, val_indices, test_indices)):

            train_loader = DenseDataLoader(dataset[train_idx],
                                           self.config.batch_size,
                                           shuffle=True)
            val_loader = DenseDataLoader(dataset[val_idx],
                                         self.config.batch_size,
                                         shuffle=False)
            test_loader = DenseDataLoader(dataset[test_idx],
                                          self.config.batch_size,
                                          shuffle=False)

            yield train_loader, val_loader, test_loader
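A hypothetical way to drive the generator above from the same class; run_fold below is an illustrative placeholder for whatever per-fold training/evaluation routine the surrounding project defines:

    def run_k_fold(self, folds=10):
        fold_scores = []
        for fold, (train_loader, val_loader, test_loader) in enumerate(
                self.k_fold_loader_generator(folds)):
            # run_fold is a stand-in, not part of the original code
            fold_scores.append(
                self.run_fold(fold, train_loader, val_loader, test_loader))
        return sum(fold_scores) / len(fold_scores)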
Example #7
def main():
    opt = OptInit().initialize()
    opt.batch_size = 1
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpuNum

    print('===> Creating dataloader...')
    # def __init__(self,
    #              root,
    #              is_train=True,
    #              is_validation=False,
    #              is_test=False,
    #              num_channel=5,
    #              pre_transform=None,
    #              pre_filter=None)
    test_dataset = BigredDataset(root=opt.test_path,
                                 is_train=False,
                                 is_validation=False,
                                 is_test=True,
                                 num_channel=5,
                                 pre_transform=T.NormalizeScale())
    test_loader = DenseDataLoader(test_dataset,
                                  batch_size=opt.batch_size,
                                  shuffle=False,
                                  num_workers=32)
    opt.n_classes = 2

    print('len(test_loader):', len(test_loader))
    print('phase: ', opt.phase)
    print('batch_size: ', opt.batch_size)
    print('use_cpu: ', opt.use_cpu)
    print('gpuNum: ', opt.gpuNum)
    print('multi_gpus: ', opt.multi_gpus)
    print('test_path: ', opt.test_path)
    print('in_channels: ', opt.in_channels)
    print('device: ', opt.device)

    print('===> Loading the network ...')
    model = DenseDeepGCN(opt).to(opt.device)
    load_package = torch.load(opt.pretrained_model)
    model, opt.best_value, opt.epoch = load_pretrained_models(
        model, opt.pretrained_model, opt.phase)
    # pdb.set_trace()  # debugger breakpoint left in the original; enable only for interactive inspection
    for item in load_package.keys():
        if (item != 'optimizer_state_dict' and item != 'state_dict'
                and item != 'scheduler_state_dict'):
            print(str(item), load_package[item])

    print('===> Start Evaluation ...')
    test(model, test_loader, opt)
Example #8
def main():
    opt = OptInit().initialize()
    print('===> Creating dataloader ...')
    train_dataset = GeoData.S3DIS(opt.train_path,
                                  5,
                                  True,
                                  pre_transform=T.NormalizeScale())
    train_loader = DenseDataLoader(train_dataset,
                                   batch_size=opt.batch_size,
                                   shuffle=True,
                                   num_workers=4)
    opt.n_classes = train_loader.dataset.num_classes

    print('===> Loading the network ...')
    model = getattr(models, opt.model_name)(opt).to(opt.device)
    if opt.multi_gpus:
        model = DataParallel(getattr(models,
                                     opt.model_name)(opt)).to(opt.device)
    print('===> loading pre-trained ...')
    model, opt.best_value, opt.epoch = load_pretrained_models(
        model, opt.pretrained_model, opt.phase)

    print('===> Init the optimizer ...')
    criterion = torch.nn.CrossEntropyLoss().to(opt.device)
    if opt.optim.lower() == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    elif opt.optim.lower() == 'radam':
        optimizer = optim.RAdam(model.parameters(), lr=opt.lr)
    else:
        raise NotImplementedError('opt.optim is not supported')
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_adjust_freq,
                                                opt.lr_decay_rate)
    optimizer, scheduler, opt.lr = load_pretrained_optimizer(
        opt.pretrained_model, optimizer, scheduler, opt.lr)

    print('===> Init Metric ...')
    opt.losses = AverageMeter()
    # opt.valid_metric = miou
    # opt.valid_values = AverageMeter()
    opt.valid_value = 0.

    print('===> start training ...')
    for _ in range(opt.total_epochs):
        opt.epoch += 1
        train(model, train_loader, optimizer, scheduler, criterion, opt)
        # valid_value = valid(model, valid_loader, valid_metric, opt)
        scheduler.step()
    print('Saving the final model. Finished!')
Example #9
def main():
    opt = OptInit().initialize()

    print('===> Creating dataloader...')
    test_dataset = GeoData.S3DIS(opt.test_path, 5, False, pre_transform=T.NormalizeScale())
    test_loader = DenseDataLoader(test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=0)
    opt.n_classes = test_loader.dataset.num_classes
    if opt.no_clutter:
        opt.n_classes -= 1

    print('===> Loading the network ...')
    model = getattr(models, opt.model_name)(opt).to(opt.device)
    model, opt.best_value, opt.epoch = load_pretrained_models(model, opt.pretrained_model, opt.phase)

    print('===> Start Evaluation ...')
    test(model, test_loader, opt)
Example #10
def get_s3dis_dataloaders(root_dir, phases, batch_size, category=5, augment=False):
    """
    Create Dataset and Dataloader classes of the S3DIS dataset, for
    the phases required (`train`, `test`).

    :param root_dir: Directory with the h5 files
    :param phases: List of phases. Should be from {`train`, `test`}
    :param batch_size: Batch size
    :param category: Area used for test set (1, 2, 3, 4, 5, or 6)

    :return: Dataset dict, Dataloader dict (one entry per phase), and the number of classes in the train set
    """
    datasets = {
        'train': S3DIS(root_dir, category, True, pre_transform=T.NormalizeScale()),
        'test': S3DIS(root_dir, category, False, pre_transform=T.NormalizeScale())
    }

    dataloaders = {x: DenseDataLoader(datasets[x], batch_size=batch_size, num_workers=4, shuffle=(x == 'train'))
                   for x in phases}
    return datasets, dataloaders, datasets['train'].num_classes
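A hypothetical call to the helper above; the root directory is a placeholder for wherever the S3DIS h5 files live:

datasets, dataloaders, num_classes = get_s3dis_dataloaders(
    root_dir='data/S3DIS', phases=['train', 'test'], batch_size=8, category=5)
for phase in ['train', 'test']:
    print(phase, len(datasets[phase]), 'samples in', len(dataloaders[phase]), 'batches')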
Example #11
def main():
    opt = OptInit().get_args()

    logging.info('===> Creating dataloader...')
    test_dataset = GeoData.S3DIS(opt.data_dir,
                                 opt.area,
                                 train=False,
                                 pre_transform=T.NormalizeScale())
    test_loader = DenseDataLoader(test_dataset,
                                  batch_size=opt.batch_size,
                                  shuffle=False,
                                  num_workers=0)
    opt.n_classes = test_loader.dataset.num_classes
    if opt.no_clutter:
        opt.n_classes -= 1

    logging.info('===> Loading the network ...')
    model = DenseDeepGCN(opt).to(opt.device)
    model, opt.best_value, opt.epoch = load_pretrained_models(
        model, opt.pretrained_model, opt.phase)

    logging.info('===> Start Evaluation ...')
    test(model, test_loader, opt)
Example #12
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_value': opt.best_value,
    }
    torch.save(state, filename)
    logging.info('save a new best model into {}'.format(filename))


if __name__ == '__main__':
    opt = OptInit()._get_args()
    logging.info('===> Creating dataloader ...')

    train_dataset = PartNet(opt.data_dir, 'sem_seg_h5', opt.category,
                            opt.level, 'train')
    train_loader = DenseDataLoader(train_dataset,
                                   batch_size=opt.batch_size,
                                   shuffle=True,
                                   num_workers=8)

    test_dataset = PartNet(opt.data_dir, 'sem_seg_h5', opt.category, opt.level,
                           'test')
    test_loader = DenseDataLoader(test_dataset,
                                  batch_size=opt.test_batch_size,
                                  shuffle=False,
                                  num_workers=8)

    val_dataset = PartNet(opt.data_dir, 'sem_seg_h5', opt.category, opt.level,
                          'val')
    val_loader = DenseDataLoader(val_dataset,
                                 batch_size=opt.test_batch_size,
                                 shuffle=False,
                                 num_workers=8)
Example #13
        return data


path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ENZYMES_d')
dataset = TUDataset(path,
                    name='ENZYMES',
                    transform=T.Compose([T.ToDense(max_nodes),
                                         MyTransform()]),
                    pre_filter=MyFilter())
dataset = dataset.shuffle()
n = (len(dataset) + 9) // 10
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]
test_loader = DenseDataLoader(test_dataset, batch_size=20)
val_loader = DenseDataLoader(val_dataset, batch_size=20)
train_loader = DenseDataLoader(train_dataset, batch_size=20)


class GNN(torch.nn.Module):
    def __init__(self,
                 in_channels,
                 hidden_channels,
                 out_channels,
                 lin=True,
                 norm=True,
                 norm_embed=True):
        super(GNN, self).__init__()

        self.conv1 = DenseSAGEConv(in_channels, hidden_channels, norm,
Example #14
File: docs.py Project: yngtodd/graf
        data, slices = self.collate(data_list)

        return data

    def len(self):
        return 1

    def get(self, idx):
        data = torch.load(
            Path(self.processed_dir).joinpath(f"{self.split}.pt"))

        return data


def random_embeddings(vocab_size, embed_dim):
    """Random word embeddings"""
    x = torch.arange(0, vocab_size)
    m = nn.Embedding(vocab_size, embed_dim)
    return m(x)


if __name__ == "__main__":
    from torch_geometric.data import DenseDataLoader

    d = DocumentGraphs("/Users/ygx/data/docs", num_vocab=501)
    loader = DenseDataLoader(d, batch_size=32, shuffle=True)

    for data in loader:
        print(data)
Example #15
def main():
    opt = OptInit().get_args()
    logging.info('===> Creating dataloader ...')
    train_dataset = GeoData.S3DIS(opt.data_dir,
                                  opt.area,
                                  True,
                                  pre_transform=T.NormalizeScale())
    train_loader = DenseDataLoader(train_dataset,
                                   batch_size=opt.batch_size,
                                   shuffle=True,
                                   num_workers=4)
    test_dataset = GeoData.S3DIS(opt.data_dir,
                                 opt.area,
                                 train=False,
                                 pre_transform=T.NormalizeScale())
    test_loader = DenseDataLoader(test_dataset,
                                  batch_size=opt.batch_size,
                                  shuffle=False,
                                  num_workers=0)
    opt.n_classes = train_loader.dataset.num_classes

    logging.info('===> Loading the network ...')
    model = DenseDeepGCN(opt).to(opt.device)
    if opt.multi_gpus:
        model = DataParallel(DenseDeepGCN(opt)).to(opt.device)

    logging.info('===> loading pre-trained ...')
    model, opt.best_value, opt.epoch = load_pretrained_models(
        model, opt.pretrained_model, opt.phase)
    logging.info(model)

    logging.info('===> Init the optimizer ...')
    criterion = torch.nn.CrossEntropyLoss().to(opt.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_adjust_freq,
                                                opt.lr_decay_rate)
    optimizer, scheduler, opt.lr = load_pretrained_optimizer(
        opt.pretrained_model, optimizer, scheduler, opt.lr)

    logging.info('===> Init Metric ...')
    opt.losses = AverageMeter()
    opt.test_value = 0.

    logging.info('===> start training ...')
    for _ in range(opt.epoch, opt.total_epochs):
        opt.epoch += 1
        logging.info('Epoch:{}'.format(opt.epoch))
        train(model, train_loader, optimizer, criterion, opt)
        if opt.epoch % opt.eval_freq == 0 and opt.eval_freq != -1:
            test(model, test_loader, opt)
        scheduler.step()

        # ------------------ save checkpoints
        # min or max. based on the metrics
        is_best = (opt.test_value > opt.best_value)
        opt.best_value = max(opt.test_value, opt.best_value)
        model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
        save_checkpoint(
            {
                'epoch': opt.epoch,
                'state_dict': model_cpu,
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_value': opt.best_value,
            }, is_best, opt.ckpt_dir, opt.exp_name)

        # ------------------ tensorboard log
        info = {
            'loss': opt.losses.avg,
            'test_value': opt.test_value,
            'lr': scheduler.get_lr()[0]
        }
        opt.writer.add_scalars('epoch', info, opt.iter)

    logging.info('Saving the final model. Finished!')
Example #16
def main():
    setSeed(10)
    opt = opt_global_inti()

    num_gpu = torch.cuda.device_count()
    assert num_gpu == opt.num_gpu,"opt.num_gpu NOT equals torch.cuda.device_count()" 

    gpu_name_list = []
    for i in range(num_gpu):
        gpu_name_list.append(torch.cuda.get_device_name(i))

    opt.gpu_list = gpu_name_list

    if opt.load_pretrain != '':
        opt, model, f_loss, optimizer, scheduler, opt_deepgcn = load_pretrained(opt)
    else:
        opt, model, f_loss, optimizer, scheduler, opt_deepgcn = creating_new_model(opt)

    print('----------------------Load Dataset----------------------')
    print('Root of dataset: ', opt.dataset_root)
    print('Phase: ', opt.phase)
    print('debug: ', opt.debug)

    #pdb.set_trace()

    if(opt.model!='deepgcn'):
        train_dataset = BigredDataSet(
            root=opt.dataset_root,
            is_train=True,
            is_validation=False,
            is_test=False,
            num_channel = opt.num_channel,
            test_code = opt.debug,
            including_ring = opt.including_ring
            )

        f_loss.load_weight(train_dataset.labelweights)

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            pin_memory=True,
            drop_last=True,
            num_workers=int(opt.num_workers))

        validation_dataset = BigredDataSet(
            root=opt.dataset_root,
            is_train=False,
            is_validation=True,
            is_test=False,
            num_channel = opt.num_channel,
            test_code = opt.debug,
            including_ring = opt.including_ring)
        validation_loader = torch.utils.data.DataLoader(
            validation_dataset,
            batch_size=opt.batch_size,
            shuffle=False,
            pin_memory=True,
            drop_last=True,
            num_workers=int(opt.num_workers))
    else:
        train_dataset = BigredDataSetPTG(root = opt.dataset_root,
                                 is_train=True,
                                 is_validation=False,
                                 is_test=False,
                                 num_channel=opt.num_channel,
                                 new_dataset = False,
                                 test_code = opt.debug,
                                 pre_transform=torch_geometric.transforms.NormalizeScale()
                                 )
        train_loader = DenseDataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
        validation_dataset = BigredDataSetPTG(root = opt.dataset_root,
                                    is_train=False,
                                    is_validation=True,
                                    is_test=False,
                                    new_dataset = False,
                                    test_code = opt.debug,
                                    num_channel=opt.num_channel,
                                    pre_transform=torch_geometric.transforms.NormalizeScale()
                                    )
        validation_loader = DenseDataLoader(validation_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers)

        labelweights = np.zeros(2)
        labelweights, _ = np.histogram(train_dataset.data.y.numpy(), range(3))
        labelweights = labelweights.astype(np.float32)
        labelweights = labelweights / np.sum(labelweights)
        labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)
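        # Editorial note (not part of the original code): the lines above compute
        # per-class weights w_c = (max(f) / f_c) ** (1/3) from the normalized class
        # frequencies f_c; e.g. frequencies [0.9, 0.1] give weights [1.0, ~2.08],
        # up-weighting the rarer class in the loss.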
        weights = torch.Tensor(labelweights).cuda()
        f_loss.load_weight(weights)

    print('train dataset num_frame: ', len(train_dataset))
    print('num_batch: ', len(train_loader))  # len(DataLoader) already counts batches

    print('validation dataset num_frame: ', len(validation_dataset))
    print('num_batch: ', len(validation_loader))

    print('Batch_size: ', opt.batch_size)

    print('----------------------Preparing Training----------------------')
    metrics_list = ['Miou','Biou','Fiou','loss','OA','time_complexicity','storage_complexicity']
    manager_test = metrics_manager(metrics_list)

    metrics_list_train = ['Miou','Biou',
                            'Fiou','loss',
                            'storage_complexicity',
                            'time_complexicity']
    manager_train = metrics_manager(metrics_list_train)


    wandb.init(project=opt.wd_project,name=opt.model_name,resume=False)
    if(opt.wandb_history == False):
        best_value = 0
    else:
        temp = wandb.restore('best_model.pth',run_path = opt.wandb_id)
        best_value = torch.load(temp.name)['Miou_validation_ave']

    wandb.config.update(opt)

    if opt.epoch_ckpt == 0:
        opt.unsave_epoch = 0
    else:
        opt.epoch_ckpt = opt.epoch_ckpt+1

    for epoch in range(opt.epoch_ckpt,opt.epoch_max):
        manager_train.reset()
        model.train()
        tic_epoch = time.perf_counter()
        print('---------------------Training----------------------')
        print("Epoch: ",epoch)
        for i, data in tqdm(enumerate(train_loader), total=len(train_loader), smoothing=0.9):
            
            if(opt.model == 'deepgcn'):
                points = torch.cat((data.pos.transpose(2, 1).unsqueeze(3), data.x.transpose(2, 1).unsqueeze(3)), 1)
                points = points[:, :opt.num_channel, :, :]
                target = data.y.cuda()
            else:
                points, target = data
                #target.shape [B,N]
                #points.shape [B,N,C]
                points, target = points.cuda(non_blocking=True), target.cuda(non_blocking=True)

            # pdb.set_trace()
            #training...
            optimizer.zero_grad()
            tic = time.perf_counter()
            pred_mics = model(points)                
            toc = time.perf_counter()
            #compute loss

            #For loss
            #target.shape [B,N] ->[B*N]
            #pred.shape [B,N,2]->[B*N,2]
            #pdb.set_trace()
            

            #pdb.set_trace()
            loss = f_loss(pred_mics, target)   

            if(opt.apex):
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            #pred.shape [B,N,2] since pred returned pass F.log_softmax
            pred, target = pred_mics[0].cpu(), target.cpu()

            #pred:[B,N,2]->[B,N]
            #pdb.set_trace()
            pred = pred.data.max(dim=2)[1]
            
            #compute iou
            Biou,Fiou = mean_iou(pred,target,num_classes =2).mean(dim=0)
            miou = (Biou+Fiou)/2

            #compute Training time complexity
            time_complexity = toc - tic


            # compute training storage complexity
            num_device = torch.cuda.device_count()
            assert num_device == opt.num_gpu,"opt.num_gpu NOT equals torch.cuda.device_count()" 
            temp = []
            for k in range(num_device):
                temp.append(torch.cuda.memory_allocated(k))
            RAM_usagePeak = torch.tensor(temp).float().mean()

            #print(loss.item())
            #print(miou.item())
            #writeup logger
            manager_train.update('loss',loss.item())
            manager_train.update('Biou',Biou.item())
            manager_train.update('Fiou',Fiou.item())
            manager_train.update('Miou',miou.item())
            manager_train.update('time_complexicity',float(1/time_complexity))
            manager_train.update('storage_complexicity',RAM_usagePeak.item())

            log_dict = {'loss_online':loss.item(),
                        'Biou_online':Biou.item(),
                        'Fiou_online':Fiou.item(),
                        'Miou_online':miou.item(),
                        'time_complexicity_online':float(1/time_complexity),
                        'storage_complexicity_online':RAM_usagePeak.item()
                        }
            if(epoch - opt.unsave_epoch>=0):
                wandb.log(log_dict)

        toc_epoch = time.perf_counter()
        time_tensor = toc_epoch-tic_epoch


        summery_dict = manager_train.summary()
        log_train_end = {}
        for key in summery_dict:
            log_train_end[key+'_train_ave'] = summery_dict[key]
            print(key+'_train_ave: ',summery_dict[key])
        
        log_train_end['Time_PerEpoch'] = time_tensor
        if(epoch - opt.unsave_epoch>=0):
            wandb.log(log_train_end)
        else:
            print('No data upload to wandb. Start upload: Epoch[%d] Current: Epoch[%d]'%(opt.unsave_epoch,epoch))

        scheduler.step()
        if(epoch % 10 == 1):
            print('---------------------Validation----------------------')
            manager_test.reset()
            model.eval()
            print("Epoch: ",epoch)
            with torch.no_grad():
                for j, data in tqdm(enumerate(validation_loader), total=len(validation_loader), smoothing=0.9):


                    if(opt.model == 'deepgcn'):
                        points = torch.cat((data.pos.transpose(2, 1).unsqueeze(3), data.x.transpose(2, 1).unsqueeze(3)), 1)
                        points = points[:, :opt.num_channel, :, :]
                        target = data.y.cuda()
                    else:
                        points, target = data
                        #target.shape [B,N]
                        #points.shape [B,N,C]
                        points, target = points.cuda(non_blocking=True), target.cuda(non_blocking=True)


                    tic = time.perf_counter()
                    pred_mics = model(points)                
                    toc = time.perf_counter()
                    
                    #pred.shape [B,N,2] since pred returned pass F.log_softmax
                    pred, target = pred_mics[0].cpu(), target.cpu()

                    #compute loss
                    test_loss = 0

                    #pred:[B,N,2]->[B,N]
                    pred = pred.data.max(dim=2)[1]
                    #compute confusion matrix
                    cm = confusion_matrix(pred,target,num_classes =2).sum(dim=0)
                    #compute OA
                    overall_correct_site = torch.diag(cm).sum()
                    overall_reference_site = cm.sum()
                    # if(overall_reference_site != opt.batch_size * opt.num_points):
                    #pdb.set_trace()
                    #assert overall_reference_site == opt.batch_size * opt.num_points,"Confusion_matrix computing error"
                    oa = float(overall_correct_site/overall_reference_site)
                    
                    #compute iou
                    Biou,Fiou = mean_iou(pred,target,num_classes =2).mean(dim=0)
                    miou = (Biou+Fiou)/2

                    #compute inference time complexity
                    time_complexity = toc - tic
                    
                    # compute inference storage complexity
                    num_device = torch.cuda.device_count()
                    assert num_device == opt.num_gpu,"opt.num_gpu NOT equals torch.cuda.device_count()" 
                    temp = []
                    for k in range(num_device):
                        temp.append(torch.cuda.memory_allocated(k))
                    RAM_usagePeak = torch.tensor(temp).float().mean()
                    #writeup logger
                    # metrics_list = ['test_loss','OA','Biou','Fiou','Miou','time_complexicity','storage_complexicity']
                    manager_test.update('loss',test_loss)
                    manager_test.update('OA',oa)
                    manager_test.update('Biou',Biou.item())
                    manager_test.update('Fiou',Fiou.item())
                    manager_test.update('Miou',miou.item())
                    manager_test.update('time_complexicity',float(1/time_complexity))
                    manager_test.update('storage_complexicity',RAM_usagePeak.item())

            
            summery_dict = manager_test.summary()

            log_val_end = {}
            for key in summery_dict:
                log_val_end[key+'_validation_ave'] = summery_dict[key]
                print(key+'_validation_ave: ',summery_dict[key])

            package = dict()
            package['state_dict'] = model.state_dict()
            package['scheduler'] = scheduler
            package['optimizer'] = optimizer
            package['epoch'] = epoch

            opt_temp = vars(opt)
            for k in opt_temp:
                package[k] = opt_temp[k]
            if(opt_deepgcn is not None):
                opt_temp = vars(opt_deepgcn)
                for k in opt_temp:
                    package[k+'_opt2'] = opt_temp[k]


            for k in log_val_end:
                package[k] = log_val_end[k]

            save_root = opt.save_root+'/val_miou%.4f_Epoch%s.pth'%(package['Miou_validation_ave'],package['epoch'])
            torch.save(package,save_root)

            print('Is Best?: ',(package['Miou_validation_ave']>best_value))
            if(package['Miou_validation_ave']>best_value):
                best_value = package['Miou_validation_ave']
                save_root = opt.save_root+'/best_model.pth'
                torch.save(package,save_root)
            if(epoch - opt.unsave_epoch>=0):
                wandb.log(log_val_end)
            else:
                print('No data upload to wandb. Start upload: Epoch[%d] Current: Epoch[%d]'%(opt.unsave_epoch,epoch))
            if(opt.debug == True):
                pdb.set_trace()
Example #17
        for i in range(0, len(dataset[test_mask]), args.batch_size)
    ]
    valid_max_nodes = [
        max([x.num_nodes for x in dataset[valid_mask][i:i + args.batch_size]])
        for i in range(0, len(dataset[valid_mask]), args.batch_size)
    ]

    train_max_nodes.append(
        max([x.num_nodes for x in dataset[train_mask][-args.batch_size - 1:]]))
    test_max_nodes.append(
        max([x.num_nodes for x in dataset[test_mask][-args.batch_size - 1:]]))
    valid_max_nodes.append(
        max([x.num_nodes for x in dataset[valid_mask][-args.batch_size - 1:]]))

    test_loader = DenseDataLoader(dataset[test_mask],
                                  batch_size=args.batch_size,
                                  shuffle=False)
    valid_loader = DenseDataLoader(dataset[valid_mask],
                                   batch_size=args.batch_size,
                                   shuffle=False)
    train_loader = DenseDataLoader(dataset[train_mask],
                                   batch_size=args.batch_size,
                                   shuffle=False)

    fold_best_acc = 0.
    fold_val_loss = 100000000.
    fold_val_acc = 0.
    patience = 0

    for epoch in range(1, 100001):
        t = time.time()
Example #18
def main():
    setSeed(10)
    opt = opt_global_inti()
    print('----------------------Load ckpt----------------------')
    pretrained_model_path = os.path.join(opt.load_pretrain,
                                         'saves/best_model.pth')
    package = torch.load(pretrained_model_path)
    para_state_dict = package['state_dict']
    opt.num_channel = package['num_channel']
    opt.time = package['time']
    opt.epoch_ckpt = package['epoch']
    try:
        state_dict = convert_state_dict(para_state_dict)
    except:
        para_state_dict = para_state_dict.state_dict()
        state_dict = convert_state_dict(para_state_dict)

    # state_dict = para_state_dict
    ckpt_, ckpt_file_name = opt.load_pretrain.split("/")
    module_name = ckpt_ + '.' + ckpt_file_name + '.' + 'model'
    MODEL = importlib.import_module(module_name)

    opt_deepgcn = []
    print(opt.model)
    if (opt.model == 'deepgcn'):
        opt_deepgcn = OptInit_deepgcn().initialize()
        model = MODEL.get_model(opt2=opt_deepgcn,
                                input_channel=opt.num_channel)
    else:
        # print('opt.num_channel: ',opt.num_channel)
        model = MODEL.get_model(input_channel=opt.num_channel,
                                is_synchoization='Instance')
    Model_Specification = MODEL.get_model_name(input_channel=opt.num_channel)
    print('----------------------Test Model----------------------')
    print('Root of prestrain model: ', pretrained_model_path)
    print('Model: ', opt.model)
    print('Pretrained model name: ', Model_Specification)
    print('Trained Date: ', opt.time)
    print('num_channel: ', opt.num_channel)
    name = input("Edit the name or press ENTER to skip: ")
    if (name != ''):
        opt.model_name = name
    else:
        opt.model_name = Model_Specification
    print('Pretrained model name: ', opt.model_name)
    package['name'] = opt.model_name
    try:
        package["Miou_validation_ave"] = package.pop("Validation_ave_miou")
    except KeyError:
        pass

    save_model(package, pretrained_model_path)
    #pdb.set_trace()

    #pdb.set_trace()
    # save_model(package,root,name)

    # if(model == 'pointnet'):
    #     #add args
    #     model = pointnet.Pointnet_sem_seg(k=2,num_channel=opt.num_channel)
    # elif(model == 'pointnetpp'):
    #     print()
    # elif(model == 'deepgcn'):
    #     print()
    # elif(model == 'dgcnn'):
    #     print()
    #pdb.set_trace()
    model.load_state_dict(state_dict)
    model.cuda()

    print('----------------------Load Dataset----------------------')
    print('Root of dataset: ', opt.dataset_root)
    print('Phase: ', opt.phase)
    print('debug: ', opt.debug)

    print('opt.model', opt.model)
    print(opt.model == 'deepgcn')
    if (opt.model != 'deepgcn'):
        test_dataset = BigredDataSet(root=opt.dataset_root,
                                     is_train=False,
                                     is_validation=False,
                                     is_test=True,
                                     num_channel=opt.num_channel,
                                     test_code=opt.debug,
                                     including_ring=opt.including_ring,
                                     file_name=opt.file_name)
        result_sheet = test_dataset.result_sheet
        file_dict = test_dataset.file_dict
        tag_Getter = tag_getter(file_dict)
        testloader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 drop_last=True,
                                                 num_workers=int(
                                                     opt.num_workers))
    else:
        test_dataset = BigredDataSetPTG(
            root=opt.dataset_root,
            is_train=False,
            is_validation=False,
            is_test=True,
            num_channel=opt.num_channel,
            new_dataset=True,
            test_code=opt.debug,
            pre_transform=torch_geometric.transforms.NormalizeScale(),
            file_name=opt.file_name)
        result_sheet = test_dataset.result_sheet
        file_dict = test_dataset.file_dict
        print(file_dict)
        tag_Getter = tag_getter(file_dict)

        testloader = DenseDataLoader(test_dataset,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=opt.num_workers)

    print('num_frame: ', len(test_dataset))
    print('batch_size: ', opt.batch_size)
    print('num_batch: ', len(testloader))  # len(DataLoader) already counts batches

    print('----------------------Testing----------------------')
    metrics_list = [
        'Miou', 'Biou', 'Fiou', 'test_loss', 'OA', 'time_complexicity',
        'storage_complexicity'
    ]
    print(result_sheet)
    for name in result_sheet:
        metrics_list.append(name)
    print(metrics_list)
    manager = metrics_manager(metrics_list)

    model.eval()
    wandb.init(project="Test", name=package['name'])
    wandb.config.update(opt)

    prediction_set = []

    with torch.no_grad():
        for j, data in tqdm(enumerate(testloader),
                            total=len(testloader),
                            smoothing=0.9):
            #pdb.set_trace()
            if (opt.model == 'deepgcn'):
                points = torch.cat((data.pos.transpose(
                    2, 1).unsqueeze(3), data.x.transpose(2, 1).unsqueeze(3)),
                                   1)
                points = points[:, :opt.num_channel, :, :].cuda()
                target = data.y.cuda()
            else:
                points, target = data
                #target.shape [B,N]
                #points.shape [B,N,C]
                points, target = points.cuda(), target.cuda()

            torch.cuda.synchronize()
            since = int(round(time.time() * 1000))
            pred_mics = model(points)
            torch.cuda.synchronize()
            #compute inference time complexity
            time_complexity = int(round(time.time() * 1000)) - since

            #print(time_complexity)

            #pred_mics[0] is pred
            #pred_mics[1] is feat [only pointnet and pointnetpp has it]

            #compute loss
            test_loss = 0

            #pred.shape [B,N,2] since pred returned pass F.log_softmax
            pred, target, points = pred_mics[0].cpu(), target.cpu(
            ), points.cpu()

            #pred:[B,N,2]->[B,N]
            # pdb.set_trace()
            pred = pred.data.max(dim=2)[1]
            prediction_set.append(pred)
            #compute confusion matrix
            cm = confusion_matrix(pred, target, num_classes=2).sum(dim=0)
            #compute OA
            overall_correct_site = torch.diag(cm).sum()
            overall_reference_site = cm.sum()
            assert overall_reference_site == opt.batch_size * opt.num_points, "Confusion_matrix computing error"
            oa = float(overall_correct_site / overall_reference_site)

            #compute iou
            Biou, Fiou = mean_iou(pred, target, num_classes=2).mean(dim=0)
            miou = (Biou + Fiou) / 2

            # compute inference storage complexity
            num_device = torch.cuda.device_count()
            assert num_device == opt.num_gpu, "opt.num_gpu NOT equals torch.cuda.device_count()"
            temp = []
            for k in range(num_device):
                temp.append(torch.cuda.memory_allocated(k))
            RAM_usagePeak = torch.tensor(temp).float().mean()
            #writeup logger
            # metrics_list = ['test_loss','OA','Biou','Fiou','Miou','time_complexicity','storage_complexicity']
            manager.update('test_loss', test_loss)
            manager.update('OA', oa)
            manager.update('Biou', Biou.item())
            manager.update('Fiou', Fiou.item())
            manager.update('Miou', miou.item())
            manager.update('time_complexicity', time_complexity)
            manager.update('storage_complexicity', RAM_usagePeak.item())
            #get tags,compute the save miou for corresponding class
            difficulty, location, isSingle, file_name = tag_Getter.get_difficulty_location_isSingle(
                j)
            manager.update(file_name, miou.item())
            manager.update(difficulty, miou.item())
            manager.update(isSingle, miou.item())

    prediction_set = np.concatenate(prediction_set, axis=0)
    point_set, label_set, ermgering_set = test_dataset.getVis(prediction_set)
    #pdb.set_trace()

    experiment_dir = Path('visulization_data/' + opt.model)
    experiment_dir.mkdir(exist_ok=True)

    root = 'visulization_data/' + opt.model

    with open(root + '/point_set.npy', 'wb') as f:
        np.save(f, point_set)
    with open(root + '/label_set.npy', 'wb') as f:
        np.save(f, label_set)
    with open(root + '/ermgering_set.npy', 'wb') as f:
        np.save(f, ermgering_set)
    with open(root + '/prediction_set.npy', 'wb') as f:
        np.save(f, prediction_set)

    summery_dict = manager.summary()
    generate_report(summery_dict, package)
    wandb.log(summery_dict)
Example #19
def main():
    NUM_POINT = 20000
    opt = OptInit().initialize()
    opt.num_worker = 32
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpuNum

    opt.printer.info('===> Creating dataloader ...')

    train_dataset = BigredDataset(root = opt.train_path,
                                 is_train=True,
                                 is_validation=False,
                                 is_test=False,
                                 num_channel=opt.num_channel,
                                 pre_transform=T.NormalizeScale()
                                 )
    train_loader = DenseDataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_worker)

    validation_dataset = BigredDataset(root = opt.train_path,
                                 is_train=False,
                                 is_validation=True,
                                 is_test=False,
                                 num_channel=opt.num_channel,
                                 pre_transform=T.NormalizeScale()
                                 )
    validation_loader = DenseDataLoader(validation_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_worker)

    opt.printer.info('===> computing Labelweight ...')

    labelweights = np.zeros(2)
    labelweights, _ = np.histogram(train_dataset.data.y.numpy(), range(3))
    labelweights = labelweights.astype(np.float32)
    labelweights = labelweights / np.sum(labelweights)
    labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)
    weights = torch.Tensor(labelweights).cuda()
    print("labelweights", weights)

    opt.n_classes = train_loader.dataset.num_classes

    opt.printer.info('===> Loading the network ...')

    opt.best_value = 0
    print("GPU:",opt.device)
    model = DenseDeepGCN(opt).to(opt.device)
    if opt.multi_gpus:
        model = DataParallel(DenseDeepGCN(opt)).to(device=opt.device)
    opt.printer.info('===> loading pre-trained ...')
    # model, opt.best_value, opt.epoch = load_pretrained_models(model, opt.pretrained_model, opt.phase)

    opt.printer.info('===> Init the optimizer ...')
    criterion = torch.nn.CrossEntropyLoss(weight = weights).to(opt.device)
    # criterion_test = torch.nn.CrossEntropyLoss(weight = weights)

    if opt.optim.lower() == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    elif opt.optim.lower() == 'radam':
        optimizer = optim.RAdam(model.parameters(), lr=opt.lr)
    else:
        raise NotImplementedError('opt.optim is not supported')
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_adjust_freq, opt.lr_decay_rate)
    # optimizer, scheduler, opt.lr = load_pretrained_optimizer(opt.pretrained_model, optimizer, scheduler, opt.lr)

    opt.printer.info('===> Init Metric ...')
    opt.losses = AverageMeter()
    # opt.test_metric = miou
    opt.test_values = AverageMeter()
    opt.test_value = 0.

    opt.printer.info('===> start training ...')
    writer = SummaryWriter()
    writer_test = SummaryWriter()
    counter_test = 0
    counter_play = 0
    start_epoch = 0
    mean_miou = AverageMeter()
    mean_loss = AverageMeter()
    mean_acc = AverageMeter()
    best_value = 0
    for epoch in range(start_epoch, opt.total_epochs):
        opt.epoch += 1
        model.train()
        total_seen_class = [0 for _ in range(opt.n_classes)]
        total_correct_class = [0 for _ in range(opt.n_classes)]
        total_iou_deno_class = [0 for _ in range(opt.n_classes)]
        ave_mIoU = 0
        total_correct = 0
        total_seen = 0
        loss_sum = 0

        mean_miou.reset()
        mean_loss.reset()
        mean_acc.reset()


        for i, data in tqdm(enumerate(train_loader), total=len(train_loader), smoothing=0.9):
            # if i % 50 == 0:
            opt.iter += 1
            if not opt.multi_gpus:
                data = data.to(opt.device)
            target = data.y
            batch_label2 = target.cpu().data.numpy()
            inputs = torch.cat((data.pos.transpose(2, 1).unsqueeze(3), data.x.transpose(2, 1).unsqueeze(3)), 1)
            inputs = inputs[:, :opt.num_channel, :, :]
            gt = data.y.to(opt.device)
            # ------------------ zero, output, loss
            optimizer.zero_grad()
            out = model(inputs)

            loss = criterion(out, gt)
            #pdb.set_trace()

            # ------------------ optimization
            loss.backward()
            optimizer.step()

            seg_pred= out.transpose(2,1)

            pred_val = seg_pred.contiguous().cpu().data.numpy()
            seg_pred = seg_pred.contiguous().view(-1, opt.n_classes)
            #pdb.set_trace()
            pred_val = np.argmax(pred_val, 2)
            batch_label = target.view(-1, 1)[:, 0].cpu().data.numpy()
            target = target.view(-1, 1)[:, 0]
            pred_choice = seg_pred.cpu().data.max(1)[1].numpy()
            correct = np.sum(pred_choice == batch_label)

            total_correct += correct
            total_seen += (opt.batch_size *NUM_POINT)
            loss_sum += loss.item()

            current_seen_class = [0 for _ in range(opt.n_classes)]
            current_correct_class = [0 for _ in range(opt.n_classes)]
            current_iou_deno_class = [0 for _ in range(opt.n_classes)]
            #pdb.set_trace()

            for l in range(opt.n_classes):
                #pdb.set_trace()
                total_seen_class[l] += np.sum((batch_label2 == l))
                total_correct_class[l] += np.sum((pred_val == l) & (batch_label2 == l))
                total_iou_deno_class[l] += np.sum(((pred_val == l) | (batch_label2 == l)))
                current_seen_class[l] = np.sum((batch_label2 == l))
                current_correct_class[l] = np.sum((pred_val == l) & (batch_label2 == l))
                current_iou_deno_class[l] = np.sum(((pred_val == l) | (batch_label2 == l)))

            #pdb.set_trace()
            writer.add_scalar('training_loss', loss.item(), counter_play)
            writer.add_scalar('training_accuracy', correct / float(opt.batch_size * NUM_POINT), counter_play)
            m_iou = np.mean(np.array(current_correct_class) / (np.array(current_iou_deno_class, dtype=np.float64) + 1e-6))
            writer.add_scalar('training_mIoU', m_iou, counter_play)
            ave_mIoU = np.mean(np.array(total_correct_class) / (np.array(total_iou_deno_class, dtype=np.float64) + 1e-6))

            # print("training_loss:",loss.item())
            # print('training_accuracy:',correct / float(opt.batch_size * NUM_POINT))
            # print('training_mIoU:',m_iou)

            mean_miou.update(m_iou)
            mean_loss.update(loss.item())
            mean_acc.update(correct / float(opt.batch_size * NUM_POINT))

            counter_play = counter_play + 1

        train_mIoU = mean_miou.avg
        train_macc = mean_acc.avg
        train_mloss = mean_loss.avg

        print('Epoch: %d, Training point avg class IoU: %f' % (epoch,train_mIoU))
        print('Epoch: %d, Training mean loss: %f' %(epoch, train_mloss))
        print('Epoch: %d, Training accuracy: %f' %(epoch, train_macc))

        mean_miou.reset()
        mean_loss.reset()
        mean_acc.reset()

        print('validation_loader')

        model.eval()
        with torch.no_grad():
            for i, data in tqdm(enumerate(validation_loader), total=len(validation_loader), smoothing=0.9):
                # if i % 50 ==0:
                if not opt.multi_gpus:
                    data = data.to(opt.device)

                target = data.y
                batch_label2 = target.cpu().data.numpy()

                inputs = torch.cat((data.pos.transpose(2, 1).unsqueeze(3), data.x.transpose(2, 1).unsqueeze(3)), 1)
                inputs = inputs[:, :opt.num_channel, :, :]
                gt = data.y.to(opt.device)
                out = model(inputs)
                loss = criterion(out, gt)
                #pdb.set_trace()

                seg_pred = out.transpose(2, 1)
                pred_val = seg_pred.contiguous().cpu().data.numpy()
                seg_pred = seg_pred.contiguous().view(-1, opt.n_classes)
                pred_val = np.argmax(pred_val, 2)
                batch_label = target.view(-1, 1)[:, 0].cpu().data.numpy()
                target = target.view(-1, 1)[:, 0]
                pred_choice = seg_pred.cpu().data.max(1)[1].numpy()
                correct = np.sum(pred_choice == batch_label)
                current_seen_class = [0 for _ in range(opt.n_classes)]
                current_correct_class = [0 for _ in range(opt.n_classes)]
                current_iou_deno_class = [0 for _ in range(opt.n_classes)]
                for l in range(opt.n_classes):

                    current_seen_class[l] = np.sum((batch_label2 == l))
                    current_correct_class[l] = np.sum((pred_val == l) & (batch_label2 == l))
                    current_iou_deno_class[l] = np.sum(((pred_val == l) | (batch_label2 == l)))
                m_iou = np.mean(
                    np.array(current_correct_class) / (np.array(current_iou_deno_class, dtype=np.float64) + 1e-6))
                mean_miou.update(m_iou)
                mean_loss.update(loss.item())
                mean_acc.update(correct / float(opt.batch_size * NUM_POINT))

        validation_mIoU = mean_miou.avg
        validation_macc = mean_acc.avg
        validation_mloss = mean_loss.avg
        writer.add_scalar('validation_loss', validation_mloss, epoch)
        print('Epoch: %d, validation mean loss: %f' %(epoch, validation_mloss))
        writer.add_scalar('validation_accuracy', validation_macc, epoch)
        print('Epoch: %d, validation accuracy: %f' %(epoch, validation_macc))
        writer.add_scalar('validation_mIoU', validation_mIoU, epoch)
        print('Epoch: %d, validation point avg class IoU: %f' % (epoch,validation_mIoU))

        model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
        package = {
            'epoch': opt.epoch,
            'state_dict': model_cpu,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_miou': train_mIoU,
            'train_accuracy': train_macc,
            'train_loss': train_mloss,
            'validation_mIoU': validation_mIoU,
            'validation_macc': validation_macc,
            'validation_mloss': validation_mloss,
            'num_channel': opt.num_channel,
            'gpuNum': opt.gpuNum,
            'time': time.ctime()
        }
        torch.save(package, 'saves/val_miou_%f_val_acc_%f_%d.pth' % (validation_mIoU, validation_macc, epoch))
        is_best = (best_value < validation_mIoU)
        print('Is Best? ', is_best)
        if is_best:
            best_value = validation_mIoU
            torch.save(package, 'saves/best_model.pth')
        print('Best IoU: %f' % (best_value))
        scheduler.step()
    opt.printer.info('Saving the final model. Finished!')
Exemplo n.º 20
0
if args.data == 'graphaf':
    with open('config/cons_optim_graphaf_config_dict.json') as f:
        conf = json.load(f)
    dataset = ZINC800(method='graphaf',
                      conf_dict=conf['data'],
                      one_shot=False,
                      use_aug=False)
else:
    print('Only graphaf datasets are supported!')
    exit()

runner = GraphAF()

if args.train:
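    # DenseDataLoader batches the graphs by stacking their equally-sized dense tensors in a new dimension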
    loader = DenseDataLoader(dataset,
                             batch_size=conf['batch_size'],
                             shuffle=True)
    runner.train_cons_optim(loader, conf['lr'], conf['weight_decay'],
                            conf['max_iters'], conf['warm_up'], conf['model'],
                            conf['pretrain_model'], conf['save_interval'],
                            conf['save_dir'])
else:
    mols_0, mols_2, mols_4, mols_6 = runner.run_cons_optim(
        dataset, conf['model'], args.model_path, conf['repeat_time'],
        conf['min_optim_time'], conf['num_max_node'], conf['temperature'],
        conf['atom_list'])
    smiles = [data.smile for data in dataset]
    evaluator = Cons_Optim_Evaluator()
    input_dict = {
        'mols_0': mols_0,
        'mols_2': mols_2,
Exemplo n.º 21
0
                     feature=args.feature,
                     empty=False,
                     name=args.dataset,
                     transform=T.ToDense(max_nodes),
                     pre_transform=ToUndirected())

print(args)

num_training = int(len(dataset) * 0.2)
num_val = int(len(dataset) * 0.1)
num_test = len(dataset) - (num_training + num_val)
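# random 20% / 10% / 70% split into training / validation / test sets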
training_set, validation_set, test_set = random_split(
    dataset, [num_training, num_val, num_test])

train_loader = DenseDataLoader(training_set,
                               batch_size=args.batch_size,
                               shuffle=True)
val_loader = DenseDataLoader(validation_set,
                             batch_size=args.batch_size,
                             shuffle=False)
test_loader = DenseDataLoader(test_set,
                              batch_size=args.batch_size,
                              shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(in_channels=dataset.num_features,
            num_classes=dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in tqdm(range(args.epochs)):
    [acc_train, _, _, _, recall_train, auc_train, _], loss_train = train()
Exemplo n.º 22
0
                    cur_shape_iou_tot += I/U
                    cur_shape_iou_cnt += 1.

            if cur_shape_iou_cnt > 0:
                cur_shape_miou = cur_shape_iou_tot / cur_shape_iou_cnt
                shape_iou_tot += cur_shape_miou
                shape_iou_cnt += 1.

    shape_mIoU = shape_iou_tot / shape_iou_cnt
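    # per-part IoU averaged over the split; index 0 (assumed to be the unlabeled part) is excluded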
    part_iou = np.divide(part_intersect[1:], part_union[1:])
    mean_part_iou = np.mean(part_iou)
    opt.printer.info("===> Category {}-{}, Part mIoU is {:.4f}".format(
                      opt.category_no, opt.category, mean_part_iou))


if __name__ == '__main__':
    opt = OptInit().initialize()
    opt.printer.info('===> Creating dataloader ...')
    test_dataset = PartNet(opt.data_dir, opt.dataset, opt.category, opt.level, 'val')
    test_loader = DenseDataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=1)
    opt.n_classes = test_loader.dataset.num_classes

    opt.printer.info('===> Loading the network ...')
    model = DenseDeepGCN(opt).to(opt.device)
    opt.printer.info('===> Loading the pre-trained model ...')
    model, opt.best_value, opt.epoch = load_pretrained_models(model, opt.pretrained_model, opt.phase)

    test(model, test_loader, opt)


Exemplo n.º 23
0
def main():
    setSeed(10)
    opt = opt_global_inti()
    print('----------------------Load ckpt----------------------')
    pretrained_model_path = os.path.join(opt.load_pretrain, 'best_model.pth')
    package = torch.load(pretrained_model_path)
    para_state_dict = package['state_dict']
    opt.num_channel = package['num_channel']
    opt.time = package['time']
    opt.epoch_ckpt = package['epoch']
    #pdb.set_trace()
    state_dict = convert_state_dict(para_state_dict)

    ckpt_, ckpt_file_name = opt.load_pretrain.split("/")
    module_name = ckpt_ + '.' + ckpt_file_name + '.' + 'model'
    MODEL = importlib.import_module(module_name)
    opt_deepgcn = []
    print(opt.model)
    if (opt.model == 'deepgcn'):
        opt_deepgcn = OptInit_deepgcn().initialize()
        model = MODEL.get_model(opt2=opt_deepgcn,
                                input_channel=opt.num_channel)
    else:
        # print('opt.num_channel: ',opt.num_channel)
        model = MODEL.get_model(input_channel=opt.num_channel)
    Model_Specification = MODEL.get_model_name(input_channel=opt.num_channel)
    f_loss = MODEL.get_loss(input_channel=opt.num_channel)

    print('----------------------Test Model----------------------')
    print('Root of pretrained model: ', pretrained_model_path)
    print('Model: ', opt.model)
    print('Pretrained model name: ', Model_Specification)
    print('Trained Date: ', opt.time)
    print('num_channel: ', opt.num_channel)
    name = input("Edit the name or press ENTER to skip: ")
    if (name != ''):
        opt.model_name = name
    else:
        opt.model_name = Model_Specification
    print('Pretrained model name: ', opt.model_name)
    package['name'] = opt.model_name
    save_model(package, pretrained_model_path)

    print(
        '----------------------Configure optimizer and scheduler----------------------'
    )
    experiment_dir = Path('ckpt/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath(opt.model_name)
    experiment_dir.mkdir(exist_ok=True)

    experiment_dir = experiment_dir.joinpath('saves')
    experiment_dir.mkdir(exist_ok=True)
    opt.save_root = str(experiment_dir)

    model.ini_ft()
    model.frozen_ft()

    if (opt.apex == True):
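        # mixed-precision path: NVIDIA apex AMP at opt level O2, with DataParallel over two GPUs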
        # model = apex.parallel.convert_syncbn_model(model)
        model.cuda()
        f_loss.cuda()

        optimizer = optim.Adam(model.parameters(),
                               lr=0.001,
                               betas=(0.9, 0.999))
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=1,
                                              gamma=0.1)
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
        model = torch.nn.DataParallel(model, device_ids=[0, 1])
    else:
        # model = apex.parallel.convert_syncbn_model(model)
        model = torch.nn.DataParallel(model)
        model.cuda()
        f_loss.cuda()
        # optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
        # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
        # optimizer = package['optimizer']
        # scheduler = package['scheduler']

    # optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
    # optimizer_dict = package['optimizer'].state_dict()
    # optimizer.load_state_dict(optimizer_dict)
    # scheduler = package['scheduler']

    if not opt.apex:
        # keep the apex-wrapped optimizer created above; only build a fresh optimizer/scheduler on the plain path
        optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    print('----------------------Load Dataset----------------------')
    print('Root of dataset: ', opt.dataset_root)
    print('Phase: ', opt.phase)
    print('debug: ', opt.debug)

    if (opt.model != 'deepgcn'):
        train_dataset = BigredDataSet_finetune(
            root=opt.dataset_root,
            is_train=True,
            is_validation=False,
            is_test=False,
            num_channel=opt.num_channel,
            test_code=opt.debug,
            including_ring=opt.including_ring)

        f_loss.load_weight(train_dataset.labelweights)

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=opt.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   drop_last=True,
                                                   num_workers=int(
                                                       opt.num_workers))

        validation_dataset = BigredDataSet_finetune(
            root=opt.dataset_root,
            is_train=False,
            is_validation=True,
            is_test=False,
            num_channel=opt.num_channel,
            test_code=opt.debug,
            including_ring=opt.including_ring)
        validation_loader = torch.utils.data.DataLoader(
            validation_dataset,
            batch_size=opt.batch_size,
            shuffle=False,
            pin_memory=True,
            drop_last=True,
            num_workers=int(opt.num_workers))
    else:
        train_dataset = BigredDataSetPTG(
            root=opt.dataset_root,
            is_train=True,
            is_validation=False,
            is_test=False,
            num_channel=opt.num_channel,
            new_dataset=False,
            test_code=opt.debug,
            pre_transform=torch_geometric.transforms.NormalizeScale())
        train_loader = DenseDataLoader(train_dataset,
                                       batch_size=opt.batch_size,
                                       shuffle=True,
                                       num_workers=opt.num_workers)
        validation_dataset = BigredDataSetPTG(
            root=opt.dataset_root,
            is_train=False,
            is_validation=True,
            is_test=False,
            new_dataset=False,
            test_code=opt.debug,
            num_channel=opt.num_channel,
            pre_transform=torch_geometric.transforms.NormalizeScale())
        validation_loader = DenseDataLoader(validation_dataset,
                                            batch_size=opt.batch_size,
                                            shuffle=False,
                                            num_workers=opt.num_workers)

        # inverse-frequency class weights over the 2 classes, smoothed with a cube root
        labelweights, _ = np.histogram(train_dataset.data.y.numpy(), range(3))
        labelweights = labelweights.astype(np.float32)
        labelweights = labelweights / np.sum(labelweights)
        labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)
        weights = torch.Tensor(labelweights).cuda()
        f_loss.load_weight(weights)

    print('train dataset num_frame: ', len(train_dataset))
    print('num_batch: ', len(train_loader))

    print('validation dataset num_frame: ', len(validation_dataset))
    print('num_batch: ', len(validation_loader))

    print('Batch_size: ', opt.batch_size)

    print('----------------------Prepareing Training----------------------')
    metrics_list = [
        'Miou', 'Biou', 'Fiou', 'loss', 'OA', 'time_complexicity',
        'storage_complexicity'
    ]
    manager_test = metrics_manager(metrics_list)

    metrics_list_train = [
        'Miou', 'Biou', 'Fiou', 'loss', 'storage_complexicity',
        'time_complexicity'
    ]
    manager_train = metrics_manager(metrics_list_train)

    wandb.init(project=opt.wd_project, name=opt.model_name, resume=False)
    if (opt.wandb_history == False):
        best_value = 0
    else:
        temp = wandb.restore('best_model.pth', run_path=opt.wandb_id)
        best_value = torch.load(temp.name)['Miou_validation_ave']

    best_value = 0
    wandb.config.update(opt)

    if opt.epoch_ckpt == 0:
        opt.unsave_epoch = 0
    else:
        opt.epoch_ckpt = opt.epoch_ckpt + 1

    # pdb.set_trace()
    for epoch in range(opt.epoch_ckpt, opt.epoch_max):
        manager_train.reset()
        model.train()
        tic_epoch = time.perf_counter()
        print('---------------------Training----------------------')
        print("Epoch: ", epoch)
        for i, data in tqdm(enumerate(train_loader),
                            total=len(train_loader),
                            smoothing=0.9):
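            # deepgcn takes a [B, C, N, 1] tensor built from data.pos and data.x;
            # the other models consume the raw (points, target) pair from the dataloader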

            if (opt.model == 'deepgcn'):
                points = torch.cat((data.pos.transpose(
                    2, 1).unsqueeze(3), data.x.transpose(2, 1).unsqueeze(3)),
                                   1)
                points = points[:, :opt.num_channel, :, :]
                target = data.y.cuda()
            else:
                points, target = data
                #target.shape [B,N]
                #points.shape [B,N,C]
                points, target = points.cuda(non_blocking=True), target.cuda(
                    non_blocking=True)

            # pdb.set_trace()
            #training...
            optimizer.zero_grad()
            tic = time.perf_counter()
            pred_mics = model(points)
            toc = time.perf_counter()
            #compute loss

            #For loss
            #target.shape [B,N] ->[B*N]
            #pred.shape [B,N,2]->[B*N,2]
            #pdb.set_trace()

            #pdb.set_trace()
            loss = f_loss(pred_mics, target)

            if (opt.apex):
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            # pred.shape [B, N, 2]; the model output has already been passed through F.log_softmax
            pred, target = pred_mics[0].cpu(), target.cpu()

            #pred:[B,N,2]->[B,N]
            #pdb.set_trace()
            pred = pred.data.max(dim=2)[1]

            #compute iou
            Biou, Fiou = mean_iou(pred, target, num_classes=2).mean(dim=0)
            miou = (Biou + Fiou) / 2

            # compute training time complexity (forward-pass time for this batch)
            time_complexity = toc - tic

            # compute training storage complexity (GPU memory currently allocated, averaged over devices)
            num_device = torch.cuda.device_count()
            assert num_device == opt.num_gpu, "opt.num_gpu NOT equals torch.cuda.device_count()"
            temp = []
            for k in range(num_device):
                temp.append(torch.cuda.memory_allocated(k))
            RAM_usagePeak = torch.tensor(temp).float().mean()

            #print(loss.item())
            #print(miou.item())
            #writeup logger
            manager_train.update('loss', loss.item())
            manager_train.update('Biou', Biou.item())
            manager_train.update('Fiou', Fiou.item())
            manager_train.update('Miou', miou.item())
            manager_train.update('time_complexicity',
                                 float(1 / time_complexity))
            manager_train.update('storage_complexicity', RAM_usagePeak.item())

            log_dict = {
                'loss_online': loss.item(),
                'Biou_online': Biou.item(),
                'Fiou_online': Fiou.item(),
                'Miou_online': miou.item(),
                'time_complexicity_online': float(1 / time_complexity),
                'storage_complexicity_online': RAM_usagePeak.item()
            }
            if (epoch - opt.unsave_epoch >= 0):
                wandb.log(log_dict)

        toc_epoch = time.perf_counter()
        time_tensor = toc_epoch - tic_epoch

        summery_dict = manager_train.summary()
        log_train_end = {}
        for key in summery_dict:
            log_train_end[key + '_train_ave'] = summery_dict[key]
            print(key + '_train_ave: ', summery_dict[key])

        log_train_end['Time_PerEpoch'] = time_tensor
        if (epoch - opt.unsave_epoch >= 0):
            wandb.log(log_train_end)
        else:
            print(
                'No data uploaded to wandb. Uploading starts at epoch %d (current epoch: %d)'
                % (opt.unsave_epoch, epoch))

        scheduler.step()
        if (epoch % 10 == 1):
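            # validate every 10 epochs (epochs 1, 11, 21, ...)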
            print('---------------------Validation----------------------')
            manager_test.reset()
            model.eval()
            print("Epoch: ", epoch)
            with torch.no_grad():
                for j, data in tqdm(enumerate(validation_loader),
                                    total=len(validation_loader),
                                    smoothing=0.9):

                    if (opt.model == 'deepgcn'):
                        points = torch.cat(
                            (data.pos.transpose(2, 1).unsqueeze(3),
                             data.x.transpose(2, 1).unsqueeze(3)), 1)
                        points = points[:, :opt.num_channel, :, :]
                        target = data.y.cuda()
                    else:
                        points, target = data
                        #target.shape [B,N]
                        #points.shape [B,N,C]
                        points, target = points.cuda(
                            non_blocking=True), target.cuda(non_blocking=True)

                    tic = time.perf_counter()
                    pred_mics = model(points)
                    toc = time.perf_counter()

                    # pred.shape [B, N, 2]; the model output has already been passed through F.log_softmax
                    pred, target = pred_mics[0].cpu(), target.cpu()

                    # the validation loss is not computed here; log a placeholder of 0
                    test_loss = 0

                    #pred:[B,N,2]->[B,N]
                    pred = pred.data.max(dim=2)[1]
                    #compute confusion matrix
                    cm = confusion_matrix(pred, target,
                                          num_classes=2).sum(dim=0)
                    #compute OA
                    overall_correct_site = torch.diag(cm).sum()
                    overall_reference_site = cm.sum()
                    # if(overall_reference_site != opt.batch_size * opt.num_points):
                    #pdb.set_trace()
                    #assert overall_reference_site == opt.batch_size * opt.num_points,"Confusion_matrix computing error"
                    oa = float(overall_correct_site / overall_reference_site)

                    #compute iou
                    Biou, Fiou = mean_iou(pred, target,
                                          num_classes=2).mean(dim=0)
                    miou = (Biou + Fiou) / 2

                    # compute inference time complexity
                    time_complexity = toc - tic

                    # compute inference storage complexity (GPU memory currently allocated, averaged over devices)
                    num_device = torch.cuda.device_count()
                    assert num_device == opt.num_gpu, "opt.num_gpu NOT equals torch.cuda.device_count()"
                    temp = []
                    for k in range(num_device):
                        temp.append(torch.cuda.memory_allocated(k))
                    RAM_usagePeak = torch.tensor(temp).float().mean()
                    #writeup logger
                    # metrics_list = ['test_loss','OA','Biou','Fiou','Miou','time_complexicity','storage_complexicity']
                    manager_test.update('loss', test_loss)
                    manager_test.update('OA', oa)
                    manager_test.update('Biou', Biou.item())
                    manager_test.update('Fiou', Fiou.item())
                    manager_test.update('Miou', miou.item())
                    manager_test.update('time_complexicity',
                                        float(1 / time_complexity))
                    manager_test.update('storage_complexicity',
                                        RAM_usagePeak.item())

            summery_dict = manager_test.summary()

            log_val_end = {}
            for key in summery_dict:
                log_val_end[key + '_validation_ave'] = summery_dict[key]
                print(key + '_validation_ave: ', summery_dict[key])

            package = dict()
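            # checkpoint: weights, optimizer/scheduler, the run options and this epoch's validation metrics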
            package['state_dict'] = model.state_dict()
            package['scheduler'] = scheduler
            package['optimizer'] = optimizer
            package['epoch'] = epoch

            opt_temp = vars(opt)
            for k in opt_temp:
                package[k] = opt_temp[k]
            # include the DeepGCN-specific options in the checkpoint when that backbone is used
            if opt.model == 'deepgcn':
                opt_temp = vars(opt_deepgcn)
                for k in opt_temp:
                    package[k + '_opt2'] = opt_temp[k]

            for k in log_val_end:
                package[k] = log_val_end[k]

            save_root = opt.save_root + '/val_miou%.4f_Epoch%s.pth' % (
                package['Miou_validation_ave'], package['epoch'])
            torch.save(package, save_root)

            print('Is Best?: ', (package['Miou_validation_ave'] > best_value))
            if (package['Miou_validation_ave'] > best_value):
                best_value = package['Miou_validation_ave']
                save_root = opt.save_root + '/best_model.pth'
                torch.save(package, save_root)
            if (epoch - opt.unsave_epoch >= 0):
                wandb.log(log_val_end)
            else:
                print(
                    'No data uploaded to wandb. Uploading starts at epoch %d (current epoch: %d)'
                    % (opt.unsave_epoch, epoch))
            if (opt.debug == True):
                pdb.set_trace()
Exemplo n.º 24
0
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    if args.random_seed:
        args.seed = np.random.randint(0, 1000)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # dataset modelnet
    pre_transform, transform = T.NormalizeScale(), T.SamplePoints(
        args.num_points)
    train_dataset = GeoData.ModelNet(os.path.join(args.data, 'modelnet10'),
                                     '10', True, transform, pre_transform)
    train_queue = DenseDataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.batch_size // 2)
    test_dataset = GeoData.ModelNet(os.path.join(args.data, 'modelnet10'),
                                    '10', False, transform, pre_transform)
    valid_queue = DenseDataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.batch_size // 2)
    n_classes = train_queue.dataset.num_classes

    criterion = torch.nn.CrossEntropyLoss().cuda()
    model = Network(args.init_channels,
                    n_classes,
                    args.num_cells,
                    criterion,
                    args.n_steps,
                    in_channels=args.in_channels,
                    emb_dims=args.emb_dims,
                    dropout=args.dropout,
                    k=args.k).cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    num_edges = model._steps * 2
    post_train = 5
    # import pdb;pdb.set_trace()
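    # total schedule: warm-up epochs, one greedy edge decision every `decision_freq` epochs, then post-training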
    args.epochs = args.warmup_dec_epoch + args.decision_freq * (
        num_edges - 1) + post_train + 1
    logging.info("total epochs: %d", args.epochs)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, args)

    normal_selected_idxs = torch.tensor(len(model.alphas_normal) * [-1],
                                        requires_grad=False,
                                        dtype=torch.int).cuda()
    normal_candidate_flags = torch.tensor(len(model.alphas_normal) * [True],
                                          requires_grad=False,
                                          dtype=torch.bool).cuda()
    logging.info('normal_selected_idxs: {}'.format(normal_selected_idxs))
    logging.info('normal_candidate_flags: {}'.format(normal_candidate_flags))
    model.normal_selected_idxs = normal_selected_idxs
    model.normal_candidate_flags = normal_candidate_flags

    print(F.softmax(torch.stack(model.alphas_normal, dim=0), dim=-1).detach())

    count = 0
    normal_probs_history = []
    train_losses, valid_losses = utils.AverageMeter(), utils.AverageMeter()
    for epoch in range(args.epochs):
        lr = scheduler.get_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)
        # training
        # import pdb;pdb.set_trace()
        att = model.show_att()
        beta = model.show_beta()
        train_acc, train_losses = train(train_queue, valid_queue, model,
                                        architect, criterion, optimizer, lr,
                                        train_losses)
        valid_overall_acc, valid_class_acc, valid_losses = infer(
            valid_queue, model, criterion, valid_losses)

        logging.info(
            'train_acc %f\tvalid_overall_acc %f \t valid_class_acc %f',
            train_acc, valid_overall_acc, valid_class_acc)
        logging.info('beta %s', beta.cpu().detach().numpy())
        logging.info('att %s', att.cpu().detach().numpy())
        # make edge decisions
        saved_memory_normal, model.normal_selected_idxs, \
        model.normal_candidate_flags = edge_decision('normal',
                                                     model.alphas_normal,
                                                     model.normal_selected_idxs,
                                                     model.normal_candidate_flags,
                                                     normal_probs_history,
                                                     epoch,
                                                     model,
                                                     args)

        if saved_memory_normal:
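            # an edge decision prunes candidate ops and frees memory, so rebuild the loaders with a larger batch size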
            del train_queue, valid_queue
            torch.cuda.empty_cache()

            count += 1
            new_batch_size = args.batch_size + args.batch_increase * count
            logging.info("new_batch_size = {}".format(new_batch_size))
            train_queue = DenseDataLoader(train_dataset,
                                          batch_size=new_batch_size,
                                          shuffle=True,
                                          num_workers=args.batch_size // 2)
            valid_queue = DenseDataLoader(test_dataset,
                                          batch_size=new_batch_size,
                                          shuffle=False,
                                          num_workers=args.batch_size // 2)
            # post validation
            if args.post_val:
                post_valid_overall_acc, post_valid_class_acc, valid_losses = infer(
                    valid_queue, model, criterion, valid_losses)
                logging.info('post_valid_overall_acc %f',
                             post_valid_overall_acc)

        writer.add_scalar('stats/train_acc', train_acc, epoch)
        writer.add_scalar('stats/valid_overall_acc', valid_overall_acc, epoch)
        writer.add_scalar('stats/valid_class_acc', valid_class_acc, epoch)
        utils.save(model, os.path.join(args.save, 'weights.pt'))
        scheduler.step()

    logging.info("#" * 30 + " Done " + "#" * 30)
    logging.info('genotype = %s', model.get_genotype())
Exemplo n.º 25
0
        """
        sample = []
        idx = self.files.index(self.allowed[idx])
        for i in range(self.sequence_length):
            sample.append(np.load(os.path.join(self.path, self.files[idx + i])))
        data = [image["data"] for image in sample]
        if self.prediction_shift > 0:
            target = np.load(os.path.join(self.path, self.files[idx + i + self.prediction_shift]))
            labels = target["labels"]
        else:
            labels = sample[-1]["labels"]
        if self.transform_image:
            for i, image in enumerate(data):
                data[i] = self.transform_image(image)
        if self.transform_labels:
            labels = self.transform_labels(labels)
        if self.transform_sample:
            data = self.transform_sample(data)
        return data, labels


if __name__ == '__main__':
    from torch_geometric.data import DenseDataLoader

    dataset = ARTCDataset('/home/joey_yu/Datasets/data_5_all')
    loader = DenseDataLoader(dataset[:100], 4, shuffle=True, num_workers=0)
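    # quick smoke test: report the dataset size and iterate over the first dense batches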

    print(len(dataset))
    for x in loader:
        print(x)