Code Example #1
 def load_pretrained_weights(self, fname):
     if torch.cuda.is_available():
         checkpoint = torch.load(fname)
     else:
         checkpoint = torch.load(fname, map_location=torch.device('cpu'))
     model_weights = self.state_dict()
     model_weights.update({k: v for k, v in checkpoint['state_dict'].items()
                           if k in model_weights})
     self.load_state_dict(model_weights)
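A minimal, self-contained sketch of the same filtered-loading pattern in plain PyTorch, using a hypothetical SmallNet module and an assumed file name; only keys that exist in both the checkpoint and the current model are copied, and everything else is left untouched.

import torch
import torch.nn as nn

class SmallNet(nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def load_pretrained_weights(self, fname):
        # fall back to CPU deserialization when no GPU is available
        map_location = None if torch.cuda.is_available() else torch.device('cpu')
        checkpoint = torch.load(fname, map_location=map_location)
        model_weights = self.state_dict()
        # keep only checkpoint entries whose keys the current model also has
        model_weights.update({k: v for k, v in checkpoint['state_dict'].items()
                              if k in model_weights})
        self.load_state_dict(model_weights)

net = SmallNet()
torch.save({'state_dict': net.state_dict()}, 'small_net.pt')  # assumed file name
net.load_pretrained_weights('small_net.pt')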
Code Example #2
 def load(self, step):
     fname = self.fname_template.format(step)
     if not os.path.exists(fname):
         print(fname + ' does not exist!')
         return
     print('Loading checkpoint from %s...' % fname)
     if porch.cuda.is_available():
         module_dict = porch.load(fname)
     else:
         module_dict = porch.load(fname, map_location=porch.device('cpu'))
     for name, module in self.module_dict.items():
         if name in module_dict:
             print(name,"loaded")
             module.load_state_dict(module_dict[name])
Code Example #3
File: inception.py Project: zzz2010/Contrib
def inception_v3(pretrained_fn=None, progress=True, **kwargs):
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models, inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        pretrained_fn (str, optional): Path to a file of pretrained weights; if given, the weights
            are loaded into the returned model
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if pretrained_fn is not None:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        if 'aux_logits' in kwargs:
            original_aux_logits = kwargs['aux_logits']
            kwargs['aux_logits'] = True
        else:
            original_aux_logits = True
        model = Inception3(**kwargs)
        state_dict = torch.load(pretrained_fn)
        model.load_state_dict(state_dict)
        if not original_aux_logits:
            model.aux_logits = False
            del model.AuxLogits
        return model

    return Inception3(**kwargs)
Code Example #4
File: wing.py Project: zzz2010/paddorch
    def load_pretrained_weights(self, fname):

        checkpoint = torch.load(fname)

        # model_weights = self.state_dict()
        # model_weights.update({k: v for k, v in checkpoint['state_dict'].items()
        #                       if k in model_weights})
        self.load_state_dict(checkpoint)
Code Example #5
        def _load_lpips_weights(self):
            own_state_dict = self.state_dict()

            state_dict = torch.load('lpips_weights.ckpt',
                                        map_location=torch.device('cpu'))
            for name, param in state_dict.items():
                if name in own_state_dict:
                    own_state_dict[name].copy_(param)
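The copy_ pattern above writes checkpoint tensors into the model's existing parameter storage instead of swapping out the whole state dict. A minimal sketch of the same idea in plain PyTorch, with a hypothetical TinyNet and an assumed file name:

import torch
import torch.nn as nn

class TinyNet(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 3)

net = TinyNet()
torch.save(net.state_dict(), 'lpips_weights_demo.ckpt')  # assumed file name

own_state_dict = net.state_dict()
state_dict = torch.load('lpips_weights_demo.ckpt',
                        map_location=torch.device('cpu'))
for name, param in state_dict.items():
    if name in own_state_dict:
        # state_dict() returns detached views of the parameters, so copying
        # into them updates the live parameters in place
        own_state_dict[name].copy_(param)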
Code Example #6
 def load(self, dir, step):
     params = torch.load(
         os.path.join(dir, self.dataset + '_params_%07d.pt' % step))
     self.genA2B.set_dict(params['genA2B'])
     self.genB2A.set_dict(params['genB2A'])
     self.disGA.set_dict(params['disGA'])
     self.disGB.set_dict(params['disGB'])
     self.disLA.set_dict(params['disLA'])
     self.disLB.set_dict(params['disLB'])
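Example #6 stores several sub-networks under named keys of one parameter file and restores them with Paddle's set_dict. The sketch below shows the same idea with plain PyTorch's load_state_dict, hypothetical sub-networks, and an assumed path:

import os
import torch
import torch.nn as nn

genA2B = nn.Linear(4, 4)  # hypothetical sub-networks, for illustration only
disGA = nn.Linear(4, 1)

# save every sub-network under its own key in a single checkpoint file
params = {'genA2B': genA2B.state_dict(), 'disGA': disGA.state_dict()}
torch.save(params, os.path.join('.', 'demo_params_0000100.pt'))  # assumed path

# restore each sub-network from the matching key
params = torch.load(os.path.join('.', 'demo_params_0000100.pt'))
genA2B.load_state_dict(params['genA2B'])
disGA.load_state_dict(params['disGA'])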
Code Example #7
def load_inception_net(parallel=False):
    inception_model = inception_v3()
    inception_model.set_dict(torch.load("inception_model.pdparams"))
    inception_model.eval()
    inception_model = WrapInception(inception_model)

    if parallel:
        print('Parallelizing Inception module...')
        inception_model = nn.DataParallel(inception_model)
    return inception_model
Code Example #8
File: alexnet.py Project: zzz2010/Contrib
def alexnet(pretrained_file=None, progress=True, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained_file (str, optional): Path to a file of pretrained weights; if given, the weights are loaded into the returned model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = AlexNet(**kwargs)
    if pretrained_file is not None:
        state_dict = torch.load(pretrained_file)
        model.load_state_dict(state_dict)
    return model
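alexnet() above follows the common factory pattern: build the architecture first, then optionally load a weights file. A self-contained sketch of that pattern with a hypothetical tiny_net factory and an assumed file name:

import torch
import torch.nn as nn

def tiny_net(pretrained_file=None):
    # hypothetical factory mirroring the alexnet() pattern above
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    if pretrained_file is not None:
        state_dict = torch.load(pretrained_file)
        model.load_state_dict(state_dict)
    return model

net = tiny_net()
torch.save(net.state_dict(), 'tiny_net.pth')  # assumed file name
net2 = tiny_net(pretrained_file='tiny_net.pth')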
Code Example #9
File: inception.py Project: zzz2010/Contrib
 def __init__(self, pretrained_fn=None):
     super().__init__()
     inception = inception_v3()
     self.block1 = nn.Sequential(inception.Conv2d_1a_3x3,
                                 inception.Conv2d_2a_3x3,
                                 inception.Conv2d_2b_3x3,
                                 nn.MaxPool2d(kernel_size=3, stride=2))
     self.block2 = nn.Sequential(inception.Conv2d_3b_1x1,
                                 inception.Conv2d_4a_3x3,
                                 nn.MaxPool2d(kernel_size=3, stride=2))
     self.block3 = nn.Sequential(inception.Mixed_5b, inception.Mixed_5c,
                                 inception.Mixed_5d, inception.Mixed_6a,
                                 inception.Mixed_6b, inception.Mixed_6c,
                                 inception.Mixed_6d, inception.Mixed_6e)
     self.block4 = nn.Sequential(inception.Mixed_7a, inception.Mixed_7b,
                                 inception.Mixed_7c,
                                 nn.AdaptiveAvgPool2d(output_size=(1, 1)))
     if pretrained_fn is not None:
         self.load_state_dict(torch.load(pretrained_fn))
Code Example #10
def main(args_test):
    if os.path.isfile(args_test.load_path):
        print("=> loading checkpoint '{}'".format(args_test.load_path))
        checkpoint = torch.load(args_test.load_path, map_location="cpu")
        print(
            "=> loaded successfully '{}' (epoch {})".format(
                args_test.load_path, checkpoint["epoch"]
            )
        )
    else:
        print("=> no checkpoint found at '{}'".format(args_test.load_path))
        return  # nothing to evaluate without a checkpoint
    args = checkpoint["opt"]

    assert args_test.gpu is None or torch.cuda.is_available()
    print("Use GPU: {} for generation".format(args_test.gpu))
    args.gpu = args_test.gpu
    args.device = torch.device("cpu") if args.gpu is None else torch.device(args.gpu)

    if args_test.dataset in GRAPH_CLASSIFICATION_DSETS:
        train_dataset = GraphClassificationDataset(
            dataset=args_test.dataset,
            rw_hops=args.rw_hops,
            subgraph_size=args.subgraph_size,
            restart_prob=args.restart_prob,
            positional_embedding_size=args.positional_embedding_size,
        )
    else:
        train_dataset = NodeClassificationDataset(
            dataset=args_test.dataset,
            rw_hops=args.rw_hops,
            subgraph_size=args.subgraph_size,
            restart_prob=args.restart_prob,
            positional_embedding_size=args.positional_embedding_size,
        )
    args.batch_size = len(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        collate_fn=batcher(),
        shuffle=False,
        num_workers=args.num_workers,
    )

    # create model and optimizer
    model = GraphEncoder(
        positional_embedding_size=args.positional_embedding_size,
        max_node_freq=args.max_node_freq,
        max_edge_freq=args.max_edge_freq,
        max_degree=args.max_degree,
        freq_embedding_size=args.freq_embedding_size,
        degree_embedding_size=args.degree_embedding_size,
        output_dim=args.hidden_size,
        node_hidden_dim=args.hidden_size,
        edge_hidden_dim=args.hidden_size,
        num_layers=args.num_layer,
        num_step_set2set=args.set2set_iter,
        num_layer_set2set=args.set2set_lstm_layer,
        gnn_model=args.model,
        norm=args.norm,
        degree_input=True,
    )

    model = model.to(args.device)

    model.load_state_dict(checkpoint["model"])

    del checkpoint

    emb = test_moco(train_loader, model, args)
    np.save(os.path.join(args.model_folder, args_test.dataset), emb.numpy())
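main() above checks the checkpoint path before deserializing and loads onto the CPU so the file can be opened on GPU-less machines. A small helper capturing just that guard (the function name and the dictionary keys are assumptions):

import os
import torch

def load_checkpoint(path):
    # verify the path before attempting to deserialize, as in main() above
    if not os.path.isfile(path):
        print("=> no checkpoint found at '{}'".format(path))
        return None
    # map_location="cpu" lets the file load even when CUDA is unavailable
    checkpoint = torch.load(path, map_location="cpu")
    print("=> loaded '{}' (epoch {})".format(path, checkpoint.get("epoch", "?")))
    return checkpoint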
Code Example #11
shuffle_indices = np.random.choice(len(y_train), size=len(y_train) * 5)
y_train = y_train[shuffle_indices]
z_train = np.concatenate([vv[0] for vv in z_y_m])[shuffle_indices]
# m_out_train,m_out_train_1,m_out_train_2 = mapping_network.finetune(pyporch.FloatTensor(z_train), pyporch.LongTensor(y_train))
m_out_train = np.concatenate([vv[2] for vv in z_y_m])[shuffle_indices]

place = fluid.CUDAPlace(0)
batch_size = 128
with fluid.dygraph.guard(place=place):
    import core.model

    if "afhq" in input_file:
        mapping_network_ema = core.model.MappingNetwork(
            16, 64, 3)  # copy.deepcopy(mapping_network)
        out_model_fn = "../expr/checkpoints/afhq/100000_nets_ema.ckpt/mapping_network.pdparams"
        mapping_network_ema.load_state_dict(porch.load(out_model_fn))

    else:
        mapping_network_ema = core.model.MappingNetwork(
            16, 64, 2)  # copy.deepcopy(mapping_network)
        out_model_fn = "../expr/checkpoints/celeba_hq/100000_nets_ema.ckpt/mapping_network.pdparams"
        mapping_network_ema.load_state_dict(porch.load(out_model_fn))

    d_optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=lr, parameter_list=mapping_network_ema.parameters())
    from tqdm import tqdm

    mapping_network_ema.train()
    z_train_p = porch.Tensor(z_train)
    y_train_p = porch.LongTensor(y_train)
    m_out_train_p = porch.Tensor(m_out_train)
Code Example #12
File: lpips.py Project: zzz2010/Contrib
 def _load_lpips_weights(self, pretrained_weights_fn):
     # the checkpoint is loaded into the module as-is; no key filtering is applied
     state_dict = torch.load(pretrained_weights_fn)
     self.load_state_dict(state_dict)
Code Example #13
def main(args):
    dgl.random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.gpu >= 0:
        torch.cuda.manual_seed(args.seed)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location="cpu")
            pretrain_args = checkpoint["opt"]
            pretrain_args.fold_idx = args.fold_idx
            pretrain_args.gpu = args.gpu
            pretrain_args.finetune = args.finetune
            pretrain_args.resume = args.resume
            pretrain_args.cv = args.cv
            pretrain_args.dataset = args.dataset
            pretrain_args.epochs = args.epochs
            pretrain_args.num_workers = args.num_workers
            if args.dataset in GRAPH_CLASSIFICATION_DSETS:
                # HACK for speeding up finetuning on graph classification tasks
                pretrain_args.num_workers = 0
            pretrain_args.batch_size = args.batch_size
            args = pretrain_args
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    args = option_update(args)
    print(args)
    if args.gpu >= 0:
        assert args.gpu is not None and torch.cuda.is_available()
        print("Use GPU: {} for training".format(args.gpu))
    assert args.positional_embedding_size % 2 == 0
    print("setting random seeds")

    mem = psutil.virtual_memory()
    print("before construct dataset", mem.used / 1024**3)
    if args.finetune:
        if args.dataset in GRAPH_CLASSIFICATION_DSETS:
            dataset = GraphClassificationDatasetLabeled(
                dataset=args.dataset,
                rw_hops=args.rw_hops,
                subgraph_size=args.subgraph_size,
                restart_prob=args.restart_prob,
                positional_embedding_size=args.positional_embedding_size,
            )
            labels = dataset.dataset.data.y.tolist()
        else:
            dataset = NodeClassificationDatasetLabeled(
                dataset=args.dataset,
                rw_hops=args.rw_hops,
                subgraph_size=args.subgraph_size,
                restart_prob=args.restart_prob,
                positional_embedding_size=args.positional_embedding_size,
            )
            labels = dataset.data.y.argmax(dim=1).tolist()

        skf = StratifiedKFold(n_splits=10,
                              shuffle=True,
                              random_state=args.seed)
        idx_list = []
        for idx in skf.split(np.zeros(len(labels)), labels):
            idx_list.append(idx)
        assert (0 <= args.fold_idx
                and args.fold_idx < 10), "fold_idx must be from 0 to 9."
        train_idx, test_idx = idx_list[args.fold_idx]
        train_dataset = torch.utils.data.Subset(dataset, train_idx)
        valid_dataset = torch.utils.data.Subset(dataset, test_idx)

    elif args.dataset == "dgl":
        train_dataset = LoadBalanceGraphDataset(
            rw_hops=args.rw_hops,
            restart_prob=args.restart_prob,
            positional_embedding_size=args.positional_embedding_size,
            num_workers=args.num_workers,
            num_samples=args.num_samples,
            dgl_graphs_file="./data/small.bin",
            num_copies=args.num_copies,
        )
    else:
        if args.dataset in GRAPH_CLASSIFICATION_DSETS:
            train_dataset = GraphClassificationDataset(
                dataset=args.dataset,
                rw_hops=args.rw_hops,
                subgraph_size=args.subgraph_size,
                restart_prob=args.restart_prob,
                positional_embedding_size=args.positional_embedding_size,
            )
        else:
            train_dataset = NodeClassificationDataset(
                dataset=args.dataset,
                rw_hops=args.rw_hops,
                subgraph_size=args.subgraph_size,
                restart_prob=args.restart_prob,
                positional_embedding_size=args.positional_embedding_size,
            )

    mem = psutil.virtual_memory()
    print("before construct dataloader", mem.used / 1024**3)
    train_loader = torch.utils.data.graph.Dataloader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        collate_fn=labeled_batcher() if args.finetune else batcher(),
        shuffle=True if args.finetune else False,
        num_workers=args.num_workers,
        worker_init_fn=None
        if args.finetune or args.dataset != "dgl" else worker_init_fn,
    )
    if args.finetune:
        valid_loader = torch.utils.data.DataLoader(
            dataset=valid_dataset,
            batch_size=args.batch_size,
            collate_fn=labeled_batcher(),
            num_workers=args.num_workers,
        )
    mem = psutil.virtual_memory()
    print("before training", mem.used / 1024**3)

    # create model and optimizer
    # n_data = train_dataset.total
    n_data = None
    import gcc.models.graph_encoder
    gcc.models.graph_encoder.final_dropout = 0  ##disable dropout
    model, model_ema = [
        GraphEncoder(
            positional_embedding_size=args.positional_embedding_size,
            max_node_freq=args.max_node_freq,
            max_edge_freq=args.max_edge_freq,
            max_degree=args.max_degree,
            freq_embedding_size=args.freq_embedding_size,
            degree_embedding_size=args.degree_embedding_size,
            output_dim=args.hidden_size,
            node_hidden_dim=args.hidden_size,
            edge_hidden_dim=args.hidden_size,
            num_layers=args.num_layer,
            num_step_set2set=args.set2set_iter,
            num_layer_set2set=args.set2set_lstm_layer,
            norm=args.norm,
            gnn_model=args.model,
            degree_input=True,
        ) for _ in range(2)
    ]

    # copy weights from `model' to `model_ema'
    if args.moco:
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    contrast = MemoryMoCo(args.hidden_size,
                          n_data,
                          args.nce_k,
                          args.nce_t,
                          use_softmax=True)
    if args.gpu >= 0:
        contrast = contrast.cuda(args.gpu)

    if args.finetune:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = NCESoftmaxLoss() if args.moco else NCESoftmaxLossNS()
        if args.gpu >= 0:
            criterion = criterion.cuda(args.gpu)
    if args.gpu >= 0:
        model = model.cuda(args.gpu)
        model_ema = model_ema.cuda(args.gpu)

    if args.finetune:
        output_layer = nn.Linear(in_features=args.hidden_size,
                                 out_features=dataset.num_classes)
        if args.gpu >= 0:
            output_layer = output_layer.cuda(args.gpu)
        output_layer_optimizer = torch.optim.Adam(
            output_layer.parameters(),
            lr=args.learning_rate,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay,
        )

        def clear_bn(m):
            classname = m.__class__.__name__
            if classname.find("BatchNorm") != -1:
                m.reset_running_stats()

        model.apply(clear_bn)

    if args.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "adam":
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.learning_rate,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "adagrad":
        optimizer = torch.optim.Adagrad(
            model.parameters(),
            lr=args.learning_rate,
            lr_decay=args.lr_decay_rate,
            weight_decay=args.weight_decay,
        )
    else:
        raise NotImplementedError

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if True:
        # print("=> loading checkpoint '{}'".format(args.resume))
        # checkpoint = torch.load(args.resume, map_location="cpu")
        import torch as th
        checkpoint = th.load("torch_models/ckpt_epoch_100.pth",
                             map_location=th.device('cpu'))
        torch_input_output_grad = th.load(
            "torch_models/torch_input_output_grad.pt",
            map_location=th.device('cpu'))
        from paddorch.convert_pretrain_model import load_pytorch_pretrain_model
        print("loading.............. model")
        paddle_state_dict = load_pytorch_pretrain_model(
            model, checkpoint["model"])
        model.load_state_dict(paddle_state_dict)
        print("loading.............. contrast")
        paddle_state_dict2 = load_pytorch_pretrain_model(
            contrast, checkpoint["contrast"])
        contrast.load_state_dict(paddle_state_dict2)
        print("loading.............. model_ema")
        paddle_state_dict3 = load_pytorch_pretrain_model(
            model_ema, checkpoint["model_ema"])
        if args.moco:
            model_ema.load_state_dict(paddle_state_dict3)

        print("=> loaded successfully '{}' (epoch {})".format(
            args.resume, checkpoint["epoch"]))
        del checkpoint
        if args.gpu >= 0:
            torch.cuda.empty_cache()
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.learning_rate * 0.1,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay,
        )
        for _ in range(1):
            graph_q, graph_k = train_dataset[0]
            graph_q2, graph_k2 = train_dataset[1]
            graph_q, graph_k = dgl.batch([graph_q, graph_q2
                                          ]), dgl.batch([graph_k, graph_k2])

            input_output_grad = []
            input_output_grad.append([graph_q, graph_k])
            model.train()
            model_ema.eval()

            feat_q = model(graph_q)
            with torch.no_grad():
                feat_k = model_ema(graph_k)

            out = contrast(feat_q, feat_k)
            loss = criterion(out)
            optimizer.zero_grad()
            loss.backward()
            input_output_grad.append([feat_q, out, loss])
            print("loss:", loss.numpy())
            optimizer.step()
            moment_update(model, model_ema, args.alpha)
        print(
            "max diff feat_q:",
            np.max(
                np.abs(torch_input_output_grad[1][0].detach().numpy() -
                       feat_q.numpy())))
        print(
            "max diff out:",
            np.max(
                np.abs(torch_input_output_grad[1][1].detach().numpy() -
                       out.numpy())))
        print(
            "max diff loss:",
            np.max(
                np.abs(torch_input_output_grad[1][2].detach().numpy() -
                       loss.numpy())))

        name2grad = dict()
        for name, p in dict(model.named_parameters()).items():
            if p.grad is not None:
                name2grad[name] = p.grad
                torch_grad = torch_input_output_grad[2][name].numpy()

                if "linear" in name and "weight" in name:
                    torch_grad = torch_grad.T
                max_grad_diff = np.max(np.abs(p.grad - torch_grad))
                print("max grad diff:", name, max_grad_diff)
        input_output_grad.append(name2grad)
Code Example #14
y_train = y_train[shuffle_indices]
z_train = np.concatenate([vv[0] for vv in z_y_m])[shuffle_indices]
# m_out_train,m_out_train_1,m_out_train_2 = mapping_network.finetune(pyporch.FloatTensor(z_train), pyporch.LongTensor(y_train))
m_out_train = np.concatenate([vv[2] for vv in z_y_m])[shuffle_indices]

place = fluid.CUDAPlace(0)
batch_size = 128
with fluid.dygraph.guard(place=place):
    import core.model

    if "afhq" in input_file:
        mapping_network_ema = core.model.MappingNetwork(16, 64,
                                                        3)  # copy.deepcopy(mapping_network)
        out_model_fn = "../expr/checkpoints/afhq/100000_nets_ema.ckpt/mapping_network.pdparams"
        mapping_network_ema.load_state_dict(
            porch.load(out_model_fn))

    else:
        mapping_network_ema = core.model.MappingNetwork(16, 64,
                                                        2)  # copy.deepcopy(mapping_network)
        out_model_fn = "../expr/checkpoints/celeba_hq/100000_nets_ema.ckpt/mapping_network.pdparams"
        mapping_network_ema.load_state_dict(
            porch.load(out_model_fn))

    d_optimizer = fluid.optimizer.AdamOptimizer(learning_rate=lr, parameter_list=mapping_network_ema.parameters())

    mapping_network_ema.train()
    z_train_p = porch.Tensor(z_train)
    y_train_p = porch.LongTensor(y_train)
    m_out_train_p = porch.Tensor(m_out_train)
    best_loss = 100000000
Code Example #15
File: train.py Project: PaddlePaddle/Contrib
def main(args):
    dgl.random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.gpu >= 0:
        torch.cuda.manual_seed(args.seed)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location="cpu")
            pretrain_args = checkpoint["opt"]
            pretrain_args.fold_idx = args.fold_idx
            pretrain_args.gpu = args.gpu
            pretrain_args.finetune = args.finetune
            pretrain_args.resume = args.resume
            pretrain_args.cv = args.cv
            pretrain_args.dataset = args.dataset
            pretrain_args.epochs = args.epochs
            pretrain_args.num_workers = args.num_workers
            if args.dataset in GRAPH_CLASSIFICATION_DSETS:
                # HACK for speeding up finetuning on graph classification tasks
                pretrain_args.num_workers = 1
            pretrain_args.batch_size = args.batch_size
            args = pretrain_args
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    args = option_update(args)
    learning_rate = float(args.learning_rate)
    print(args)
    if args.gpu >= 0:
        assert args.gpu is not None and torch.cuda.is_available()
        print("Use GPU: {} for training".format(args.gpu))
    assert args.positional_embedding_size % 2 == 0
    print("setting random seeds")

    mem = psutil.virtual_memory()
    print("before construct dataset", mem.used / 1024**3)
    if args.finetune:
        if args.dataset in GRAPH_CLASSIFICATION_DSETS:
            dataset = GraphClassificationDatasetLabeled(
                dataset=args.dataset,
                rw_hops=args.rw_hops,
                subgraph_size=args.subgraph_size,
                restart_prob=args.restart_prob,
                positional_embedding_size=args.positional_embedding_size,
            )
            labels = dataset.dataset.data.y.tolist()
        else:
            dataset = NodeClassificationDatasetLabeled(
                dataset=args.dataset,
                rw_hops=args.rw_hops,
                subgraph_size=args.subgraph_size,
                restart_prob=args.restart_prob,
                positional_embedding_size=args.positional_embedding_size,
            )
            labels = dataset.data.y.argmax(dim=1).tolist()

        skf = StratifiedKFold(n_splits=10,
                              shuffle=True,
                              random_state=args.seed)
        idx_list = []
        for idx in skf.split(np.zeros(len(labels)), labels):
            idx_list.append(idx)
        assert (0 <= args.fold_idx
                and args.fold_idx < 10), "fold_idx must be from 0 to 9."
        train_idx, test_idx = idx_list[args.fold_idx]
        train_dataset = torch.utils.data.Subset(dataset, train_idx)
        valid_dataset = torch.utils.data.Subset(dataset, test_idx)

    elif args.dataset == "dgl":
        train_dataset = LoadBalanceGraphDataset(
            rw_hops=args.rw_hops,
            restart_prob=args.restart_prob,
            positional_embedding_size=args.positional_embedding_size,
            num_workers=args.num_workers,
            num_samples=args.num_samples,
            dgl_graphs_file="./data/small.bin",
            num_copies=args.num_copies,
        )
    else:
        if args.dataset in GRAPH_CLASSIFICATION_DSETS:
            train_dataset = GraphClassificationDataset(
                dataset=args.dataset,
                rw_hops=args.rw_hops,
                subgraph_size=args.subgraph_size,
                restart_prob=args.restart_prob,
                positional_embedding_size=args.positional_embedding_size,
            )
        else:
            train_dataset = NodeClassificationDataset(
                dataset=args.dataset,
                rw_hops=args.rw_hops,
                subgraph_size=args.subgraph_size,
                restart_prob=args.restart_prob,
                positional_embedding_size=args.positional_embedding_size,
            )

    mem = psutil.virtual_memory()
    print("before construct dataloader", mem.used / 1024**3)
    train_loader = torch.utils.data.graph.Dataloader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        collate_fn=labeled_batcher() if args.finetune else batcher(),
        shuffle=True if args.finetune else False,
        num_workers=args.num_workers,
        worker_init_fn=None
        if args.finetune or args.dataset != "dgl" else worker_init_fn,
    )
    if args.finetune:
        valid_loader = torch.utils.data.graph.Dataloader(
            dataset=valid_dataset,
            batch_size=args.batch_size,
            collate_fn=labeled_batcher(),
            num_workers=args.num_workers,
        )
    mem = psutil.virtual_memory()
    print("before training", mem.used / 1024**3)

    # create model and optimizer
    # n_data = train_dataset.total
    n_data = None

    model, model_ema = [
        GraphEncoder(
            positional_embedding_size=args.positional_embedding_size,
            max_node_freq=args.max_node_freq,
            max_edge_freq=args.max_edge_freq,
            max_degree=args.max_degree,
            freq_embedding_size=args.freq_embedding_size,
            degree_embedding_size=args.degree_embedding_size,
            output_dim=args.hidden_size,
            node_hidden_dim=args.hidden_size,
            edge_hidden_dim=args.hidden_size,
            num_layers=args.num_layer,
            num_step_set2set=args.set2set_iter,
            num_layer_set2set=args.set2set_lstm_layer,
            norm=args.norm,
            gnn_model=args.model,
            degree_input=True,
        ) for _ in range(2)
    ]

    # copy weights from `model' to `model_ema'
    if args.moco:
        # model_ema.load_state_dict(model.state_dict()) ##complete copy of model
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    contrast = MemoryMoCo(args.hidden_size,
                          n_data,
                          args.nce_k,
                          args.nce_t,
                          use_softmax=True)
    if args.gpu >= 0:
        contrast = contrast

    if args.finetune:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = NCESoftmaxLoss() if args.moco else NCESoftmaxLossNS()
        if args.gpu >= 0:
            criterion = criterion
    if args.gpu >= 0:
        model = model
        model_ema = model_ema

    import paddle
    if args.finetune:
        output_layer = nn.Linear(in_features=args.hidden_size,
                                 out_features=dataset.num_classes)
        if args.gpu >= 0:
            output_layer = output_layer
        output_layer_optimizer = torch.optim.Adam(
            output_layer.parameters(),
            lr=args.learning_rate,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay,
            grad_clip=paddle.nn.clip.ClipGradByValue(max=1))

        def clear_bn(m):
            classname = m.__class__.__name__
            if classname.find("BatchNorm") != -1:
                m.reset_running_stats()

        model.apply(clear_bn)

    if args.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "adam":
        if args.finetune:
            optimizer = torch.optim.Adam(
                model.parameters(),
                lr=learning_rate,
                betas=(args.beta1, args.beta2),
                weight_decay=args.weight_decay,
                grad_clip=paddle.nn.clip.ClipGradByValue(max=1),
            )
        else:
            optimizer = torch.optim.Adam(
                model.parameters(),
                lr=learning_rate,
                betas=(args.beta1, args.beta2),
                weight_decay=args.weight_decay,
                grad_clip=paddle.nn.clip.ClipGradByNorm(args.clip_norm))
    elif args.optimizer == "adagrad":
        optimizer = torch.optim.Adagrad(
            model.parameters(),
            lr=args.learning_rate,
            lr_decay=args.lr_decay_rate,
            weight_decay=args.weight_decay,
        )
    else:
        raise NotImplementedError

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if args.finetune:  # if a finetuned checkpoint exists, resume from it instead
            if os.path.isdir(args.model_folder + "/current.pth"):
                args.resume = args.model_folder + "/current.pth"
                print("change resume model to finetune model path:",
                      args.resume)
                # find the last completed checkpoint epoch
                import glob
                ckpt_epoches = glob.glob(args.model_folder +
                                         "/ckpt_epoch*.pth")
                if len(ckpt_epoches) > 0:
                    args.start_epoch = sorted([
                        int(
                            os.path.basename(x).replace(".pth", "").replace(
                                "ckpt_epoch_", "")) for x in ckpt_epoches
                    ])[-1] + 1
                    print("starting epoch:", args.start_epoch)
                    args.epochs = args.epochs + args.start_epoch - 1
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume, map_location="cpu")
        # checkpoint = torch.load(args.resume)
        # args.start_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(checkpoint["model"])
        # optimizer.load_state_dict(checkpoint["optimizer"])
        contrast.load_state_dict(checkpoint["contrast"])
        if args.moco:
            model_ema.load_state_dict(checkpoint["model_ema"])

        print("=> loaded successfully '{}' ".format(args.resume))
        if args.finetune:
            if "output_layer" in checkpoint:
                output_layer.load_state_dict(checkpoint["output_layer"])
                print("loaded output layer")
        # del checkpoint
        if args.gpu >= 0:
            torch.cuda.empty_cache()

    # tensorboard
    #  logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)
    sw = LogWriter(logdir=args.tb_folder)

    import gc
    gc.enable()
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        try:
            if args.finetune:
                loss, _ = train_finetune(
                    epoch,
                    train_loader,
                    model,
                    output_layer,
                    criterion,
                    optimizer,
                    output_layer_optimizer,
                    sw,
                    args,
                )
            else:

                loss = train_moco(
                    epoch,
                    train_loader,
                    model,
                    model_ema,
                    contrast,
                    criterion,
                    optimizer,
                    sw,
                    args,
                )
        except Exception as e:
            print("Error in Epoch", epoch, ":", e)
            continue
        time2 = time.time()
        print("epoch {}, total time {:.2f}".format(epoch, time2 - time1))

        # save model
        if epoch % args.save_freq == 0:
            print("==> Saving...")
            state = {
                "opt": vars(args).copy(),
                "model": model.state_dict(),
                "contrast": contrast.state_dict(),
                "optimizer": optimizer.state_dict()
            }
            if args.moco:
                state["model_ema"] = model_ema.state_dict()
            if args.finetune:
                state['output_layer'] = output_layer.state_dict()
            save_file = os.path.join(
                args.model_folder,
                "ckpt_epoch_{epoch}.pth".format(epoch=epoch))
            torch.save(state, save_file)
            # help release GPU memory
            # del state

        # saving the model
        print("==> Saving...")
        state = {
            "opt": vars(args).copy(),
            "model": model.state_dict(),
            "contrast": contrast.state_dict(),
            "optimizer": optimizer.state_dict()
        }
        if args.moco:
            state["model_ema"] = model_ema.state_dict()
        if args.finetune:
            state['output_layer'] = output_layer.state_dict()
        save_file = os.path.join(args.model_folder, "current.pth")
        torch.save(state, save_file)
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder,
                "ckpt_epoch_{epoch}.pth".format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        # del state
        if args.gpu >= 0:
            torch.cuda.empty_cache()

        if args.finetune:
            valid_loss, valid_f1 = test_finetune(epoch, valid_loader, model,
                                                 output_layer, criterion, sw,
                                                 args)
            print("epoch %d| valid f1: %.3f" % (epoch, valid_f1))

    # del model,model_ema,train_loader
    gc.collect()
    return valid_f1 if args.finetune else None
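The training loop in Example #15 keeps a rolling current.pth that is overwritten every epoch plus periodic ckpt_epoch_N.pth snapshots. A minimal sketch of that checkpointing scheme with plain PyTorch and a hypothetical model and settings:

import os
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                       # hypothetical model
optimizer = torch.optim.Adam(model.parameters())
model_folder, save_freq, epochs = '.', 2, 4   # assumed settings

for epoch in range(1, epochs + 1):
    # ... one epoch of training would run here ...
    state = {"model": model.state_dict(),
             "optimizer": optimizer.state_dict(),
             "epoch": epoch}
    # rolling checkpoint, overwritten every epoch
    torch.save(state, os.path.join(model_folder, "current.pth"))
    # periodic snapshot kept around for later inspection or resuming
    if epoch % save_freq == 0:
        torch.save(state, os.path.join(model_folder, "ckpt_epoch_{}.pth".format(epoch)))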