Ejemplo n.º 1
0
def compute_d_loss(nets,
                   args,
                   x_real,
                   y_org,
                   y_trg,
                   z_trg=None,
                   x_ref=None,
                   masks=None):
    """Compute the discriminator loss for one batch.

    Exactly one of ``z_trg`` and ``x_ref`` must be supplied; it decides
    whether the target style code comes from ``nets.mapping_network`` or
    from ``nets.style_encoder``.

    Args:
        nets: Container exposing ``discriminator``, ``generator``,
            ``mapping_network`` and ``style_encoder`` callables.
        args: Options object; only ``args.lambda_reg`` is read here.
        x_real: Batch of real inputs. Its ``stop_gradient`` flag is set to
            ``False`` before the discriminator call (presumably so that
            ``r1_reg`` can take gradients w.r.t. the input -- confirm
            against ``r1_reg``).
        y_org: Domain labels passed with ``x_real`` to the discriminator.
        y_trg: Target domain labels used for the fake branch.
        z_trg: Optional latent code fed to the mapping network.
        x_ref: Optional reference batch fed to the style encoder.
        masks: Forwarded unchanged to the generator.

    Returns:
        A ``(loss, summary)`` pair. ``loss`` is ``porch.sum`` of
        ``loss_real + loss_fake + args.lambda_reg * loss_reg``. ``summary``
        is a ``Munch`` with ``real`` and ``fake`` taken as the first element
        of the flattened numpy losses, and ``reg`` passed through exactly as
        ``r1_reg`` returned it.

    Raises:
        AssertionError: If both or neither of ``z_trg`` / ``x_ref`` are given.
    """
    # Exactly one style source must be provided.
    assert (z_trg is None) != (x_ref is None)

    # Real branch: adversarial term plus the regulariser on real inputs.
    x_real.stop_gradient = False
    real_logits = nets.discriminator(x_real, y_org)
    loss_real = adv_loss(real_logits, 1)
    loss_reg = r1_reg(real_logits, x_real)

    # Fake branch: the generator side runs without gradient tracking, only
    # the discriminator call on the result is tracked.
    with porch.no_grad():
        use_latent = z_trg is not None
        s_trg = (nets.mapping_network(z_trg, y_trg)
                 if use_latent else nets.style_encoder(x_ref, y_trg))
        x_fake = nets.generator(x_real, s_trg, masks=masks)
    fake_logits = nets.discriminator(x_fake, y_trg)
    loss_fake = adv_loss(fake_logits, 0)

    loss = porch.sum(loss_real + loss_fake + args.lambda_reg * loss_reg)
    summary = Munch(real=loss_real.numpy().flatten()[0],
                    fake=loss_fake.numpy().flatten()[0],
                    reg=loss_reg)
    return loss, summary
Ejemplo n.º 2
0
        def __init__(self,
                     height=64,
                     width=64,
                     with_r=False,
                     with_boundary=False):
            """Precompute normalised coordinate grids for a fixed image size.

            Args:
                height: Number of rows of the coordinate grid.
                width: Number of columns of the coordinate grid.
                with_r: When true, a third channel holding the distance from
                    the grid centre (divided by its maximum) is appended.
                with_boundary: Only stored on the instance here.

            Attributes set:
                coords: ``(1, 2 or 3, height, width)`` tensor.
                x_coords: ``(height, width)`` grid varying along dim 0.
                y_coords: ``(height, width)`` grid varying along dim 1.

            All three tensors are moved to CUDA when it is available,
            otherwise they stay on the CPU.

            NOTE(review): ``height == 1`` or ``width == 1`` makes the
            normalisation divisor zero.
            """
            super(AddCoordsTh, self).__init__()
            self.with_r = with_r
            self.with_boundary = with_boundary
            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')

            with torch.no_grad():
                row_index = torch.arange(height).unsqueeze(1)
                col_index = torch.arange(width).unsqueeze(0)
                x_coords = row_index.expand(height, width).float()
                y_coords = col_index.expand(height, width).float()
                # Map indices from [0, size - 1] onto [-1, 1].
                x_coords = (x_coords / (height - 1)) * 2 - 1
                y_coords = (y_coords / (width - 1)) * 2 - 1
                # (2, height, width)
                coords = torch.stack([x_coords, y_coords], dim=0)

                if self.with_r:
                    # Radial channel, scaled so its largest value is 1.
                    radius = torch.sqrt(x_coords.pow(2) + y_coords.pow(2))
                    radius = (radius / radius.max()).unsqueeze(0)
                    coords = torch.cat([coords, radius], dim=0)

                # (1, 2 or 3, height, width)
                self.coords = coords.unsqueeze(0).to(device)
                self.x_coords = x_coords.to(device)
                self.y_coords = y_coords.to(device)
Ejemplo n.º 3
0
    def __init__(self,
                 num_classes=1000,
                 aux_logits=True,
                 transform_input=False,
                 inception_blocks=None):
        """Assemble the Inception v3 layers and re-initialise their weights.

        Args:
            num_classes: Output size of the final ``fc`` layer and of the
                auxiliary head.
            aux_logits: When true, an ``AuxLogits`` head is created.
            transform_input: Only stored on the instance here.
            inception_blocks: Optional sequence of exactly seven block
                classes, in the order (conv, A, B, C, D, E, aux). Defaults to
                ``BasicConv2d`` and ``InceptionA`` .. ``InceptionAux``.

        Raises:
            AssertionError: If ``inception_blocks`` does not have 7 entries.
        """
        super(Inception3, self).__init__()
        if inception_blocks is None:
            inception_blocks = [
                BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD,
                InceptionE, InceptionAux
            ]
        assert len(inception_blocks) == 7
        (conv_block, inception_a, inception_b, inception_c, inception_d,
         inception_e, inception_aux) = inception_blocks

        self.aux_logits = aux_logits
        self.transform_input = transform_input

        # Stem convolutions.
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)

        # Mixed blocks; registration order is kept as in the original.
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)
        self.fc = nn.Linear(2048, num_classes)

        # Weight initialisation. The loop walks ``self._sub_layers`` (an
        # earlier variant iterated ``self.modules()``).
        # NOTE(review): presumably this visits direct children only, so
        # layers nested inside the blocks would not be matched -- confirm.
        for layer in self._sub_layers.values():
            if isinstance(layer, (dygraph.Conv2D, dygraph.Linear)):
                import scipy.stats as stats
                stddev = getattr(layer, 'stddev', 0.1)
                # Normal distribution truncated to [-2, 2] before scaling.
                truncated = stats.truncnorm(-2, 2, scale=stddev)
                n_weights = np.prod(layer.weight.shape)
                samples = truncated.rvs(n_weights).astype("float32")
                values = torch.as_tensor(samples)
                values = values.view(*layer.weight.shape)
                with torch.no_grad():
                    fluid.layers.assign(values, layer.weight)

            elif isinstance(layer, dygraph.BatchNorm):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
Ejemplo n.º 4
0
def accumulate_inception_activations(sample, net, num_inception_images=50000):
    """Run ``net`` on sampled batches until enough activations are gathered.

    Args:
        sample: Zero-argument callable returning ``(images, labels)``.
            ``images`` must support ``.astype("float32")``.
        net: Callable mapping the float32 image batch to ``(pool, logits)``.
        num_inception_images: Minimum number of rows to collect. Whole
            batches are kept, so the result may hold more rows than this.

    Returns:
        ``(pool, probs, labels)``, each concatenated along dim 0. ``probs``
        is the softmax of the logits over dim 1.
    """
    pool, logits, labels = [], [], []
    # Track the row count incrementally instead of re-concatenating every
    # accumulated batch on each loop test.
    num_collected = 0
    while num_collected < num_inception_images:
        with torch.no_grad():
            images, labels_val = sample()
            pool_val, logits_val = net(images.astype("float32"))
            pool.append(pool_val)
            logits.append(F.softmax(logits_val, 1))
            labels.append(labels_val)
        num_collected += logits[-1].shape[0]
    return torch.cat(pool, 0), torch.cat(logits, 0), torch.cat(labels, 0)
Ejemplo n.º 5
0
def sqrt_newton_schulz(A, numIters, dtype=None):
    """Approximate a batched matrix square root with Newton-Schulz steps.

    Args:
        A: Batch of square matrices indexed ``(batch, dim, dim)``, as implied
            by the ``bmm`` calls and the ``(batch, 1, 1)`` norm reshape.
        numIters: Number of Newton-Schulz iterations to run.
        dtype: Defaults to ``A.dtype``.
            NOTE(review): the value is resolved but not used afterwards; the
            identity matrices are always cast to ``"float32"``.

    Returns:
        Tensor shaped like ``A`` holding the square-root estimate. The whole
        computation runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        if dtype is None:
            dtype = A.dtype
        batchSize, dim = A.shape[0], A.shape[1]

        def batched_identity():
            # One (dim, dim) identity per batch entry, as float32.
            eye = torch.eye(dim, dim).view(1, dim, dim)
            return torch.Tensor(eye.repeat(batchSize, 1, 1).astype("float32"))

        # Frobenius norm of each matrix; A is scaled by it before iterating
        # and the result is rescaled by its square root at the end.
        normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt()
        Y = torch.Tensor(A / (normA.view(batchSize, 1, 1).expand(*A.shape)))
        I = batched_identity()
        Z = batched_identity()
        for _ in range(numIters):
            T = torch.Tensor(0.5 * (3.0 * I - Z.bmm(Y)))
            Y = Y.bmm(T)
            Z = T.bmm(Z)

        rescale = torch.sqrt(normA).view(batchSize, 1, 1).expand(*A.shape)
        sA = Y * rescale
    return sA
Ejemplo n.º 6
0
def test_moco(train_loader, model, opt):
    """
    one epoch training for moco
    """

    model.eval()

    emb_list = []
    for idx, batch in enumerate(train_loader):
        graph_q, graph_k = batch
        bsz = graph_q.batch_size
        graph_q.to(opt.device)
        graph_k.to(opt.device)

        with torch.no_grad():
            feat_q = model(graph_q)
            feat_k = model(graph_k)

        assert feat_q.shape == (bsz, opt.hidden_size)
        emb_list.append(((feat_q + feat_k) / 2).detach().cpu())
    return torch.cat(emb_list)
Ejemplo n.º 7
0
def power_iteration(W, u_, update=True, eps=1e-12):
    """Run one power-iteration step for each vector in ``u_``.

    For every ``u`` the step computes ``v = normalize(u @ W)`` and then
    ``u = normalize(v @ W^T)``, removing components along the vectors already
    found via ``gram_schmidt``, and finally the estimate
    ``sv = squeeze((v @ W^T) @ u^T)``.

    Args:
        W: Weight matrix. ``u @ W`` and ``v @ W^T`` must both be valid
            products, so each ``u`` is assumed to be shaped ``(1, W.shape[0])``
            -- TODO confirm against the caller.
        u_: Sequence of current left-vector estimates.
        update: When true, each refined ``u`` is written back through
            ``torch.copy(u, u_[i])`` (the earlier form was ``u_[i][:] = u``).
        eps: Epsilon forwarded to ``F.normalize``.

    Returns:
        ``(svs, us, vs)``: lists of singular-value estimates, left vectors
        and right vectors, one entry per element of ``u_``.
    """
    Wt = torch.Tensor(W).t()
    us, vs, svs = [], [], []
    for i, u in enumerate(u_):
        # The vector updates themselves are excluded from gradient tracking.
        with torch.no_grad():
            v = torch.matmul(u, W)
            # Gram-Schmidt: subtract components of the vectors found so far.
            v = F.normalize(gram_schmidt(v, vs), eps=eps)
            vs.append(v)
            # Update the other singular vector the same way.
            u = torch.matmul(v, Wt)
            u = F.normalize(gram_schmidt(u, us), eps=eps)
            us.append(u)
            if update:
                torch.copy(u, u_[i])
        # Deliberately outside no_grad and recomputed from Wt, presumably so
        # the estimate stays differentiable w.r.t. W -- do not hoist into the
        # block above.
        svs.append(torch.squeeze(torch.matmul(torch.matmul(v, Wt), u.t())))
    return svs, us, vs
def run(config):
    """Compute Inception statistics of a dataset and save its FID moments.

    Runs the Inception network over the first loader returned by
    ``utils.get_data_loaders``, prints the Inception Score of the data and
    saves the mean and covariance of the pool features to
    ``<dataset>_inception_moments.npz``.

    Args:
        config: Configuration dict. ``config['drop_last']`` is overwritten
            with ``False``; ``'parallel'`` and ``'dataset'`` are read, and the
            whole dict is forwarded to ``utils.get_data_loaders``.
    """
    # Keep the final partial batch.
    config['drop_last'] = False
    loaders = utils.get_data_loaders(**config)

    # Load inception net
    net = inception_utils.load_inception_net(parallel=config['parallel'])
    device = 'cuda'
    pool_batches, logit_batches, label_batches = [], [], []
    for x, y in tqdm(loaders[0]):
        x = x.to(device)
        with torch.no_grad():
            pool_val, logits_val = net(x)
            pool_batches.append(np.asarray(pool_val))
            logit_batches.append(np.asarray(F.softmax(logits_val, 1)))
            label_batches.append(np.asarray(y))

    pool = np.concatenate(pool_batches, 0)
    logits = np.concatenate(logit_batches, 0)
    # Labels are gathered (e.g. for the optional dump below) but not used by
    # the metrics.
    labels = np.concatenate(label_batches, 0)
    # uncomment to save pool, logits, and labels to disk
    # print('Saving pool, logits, and labels to disk...')
    # np.savez(config['dataset']+'_inception_activations.npz',
    #           {'pool': pool, 'logits': logits, 'labels': labels})

    # Inception Score of the data itself.
    print('Calculating inception metrics...')
    IS_mean, IS_std = inception_utils.calculate_inception_score(logits)
    print('Training data from dataset %s has IS of %5.5f +/- %5.5f' %
          (config['dataset'], IS_mean, IS_std))

    # Moments of the pool features, used later for FID.
    print('Calculating means and covariances...')
    mu = np.mean(pool, axis=0)
    sigma = np.cov(pool, rowvar=False)
    print('Saving calculated means and covariances to disk...')
    # NOTE(review): str.strip('_hdf5') removes any of the characters
    # '_', 'h', 'd', 'f', '5' from BOTH ends rather than the '_hdf5' suffix.
    # It is kept as-is because the FID code is said to strip the name the
    # same way; change both sides together.
    moments_path = config['dataset'].strip('_hdf5') + '_inception_moments.npz'
    np.savez(moments_path, mu=mu, sigma=sigma)
Ejemplo n.º 9
0
 def W_(self):
     """Return ``self.weight`` divided by its first singular-value estimate.

     The weight is flattened to a matrix, ``self.num_itrs`` rounds of
     ``power_iteration`` are run on it with ``self.u``, the resulting
     estimates are copied into ``self.sv``, and the weight scaled by
     ``1 / svs[0]`` is returned.

     NOTE(review): ``svs`` is only bound inside the loop, so
     ``self.num_itrs == 0`` would leave it undefined.
     """
     # NOTE(review): training mode is forced on at every call, so the
     # ``if self.training`` below is presumably always taken -- confirm this
     # is intended and not a leftover from porting.
     self.training = True
     weight = torch.Tensor(self.weight)
     if isinstance(self, SNLinear):
         # The linear layer stores its weight transposed relative to
         # PyTorch, so it is transposed here.
         W_mat = weight.t()
     else:
         W_mat = weight.view(self.weight.shape[0], -1)
     if self.transpose:
         W_mat = W_mat.t()

     # Apply num_itrs power iterations; only the last round's result is kept.
     for _ in range(self.num_itrs):
         svs, us, vs = power_iteration(W_mat,
                                       self.u,
                                       update=self.training,
                                       eps=self.eps)

     # Store the estimates. Done under no_grad to avoid the memory leak the
     # original comment warns about.
     if self.training:
         with torch.no_grad():
             for idx, estimate in enumerate(svs):
                 torch.copy(estimate, self.sv[idx])
     return self.weight / svs[0]
Ejemplo n.º 10
0
def test_finetune(epoch, valid_loader, model, output_layer, criterion, sw,
                  opt):
    """Evaluate the encoder plus classification head on ``valid_loader``.

    Args:
        epoch: Current epoch index; used for the logging step and the
            printed summary.
        valid_loader: Sized iterable of ``(graph_q, y)`` batches.
        model: Encoder, put into eval mode.
        output_layer: Classification head, put into eval mode.
        criterion: Loss callable applied to ``(out, y)``.
        sw: Writer exposing ``add_scalar(tag, value, step)``.
        opt: Options object; ``hidden_size`` and ``model_folder`` are read.

    Returns:
        ``(average loss, average micro-F1)`` over the loader, each weighted
        by batch size through ``AverageMeter``.

    Raises:
        AssertionError: If the encoder output is not shaped
            ``(batch_size, opt.hidden_size)``.
    """
    n_batch = len(valid_loader)
    model.eval()
    output_layer.eval()

    epoch_loss_meter = AverageMeter()
    epoch_f1_meter = AverageMeter()

    for graph_q, y in valid_loader:
        bsz = graph_q.batch_size

        # Forward pass without gradient tracking.
        with torch.no_grad():
            feat_q = model(graph_q)
            assert feat_q.shape == (graph_q.batch_size, opt.hidden_size)
            out = output_layer(feat_q)
        loss = torch.convertTensor(criterion(out, y))

        # Micro-averaged F1 of the arg-max predictions.
        y_true = y.cpu().numpy()
        y_pred = out.argmax(dim=1).cpu().numpy()
        f1 = f1_score(y_true, y_pred, average="micro")

        # Update the running averages.
        epoch_loss_meter.update(loss.item(), bsz)
        epoch_f1_meter.update(f1, bsz)

    # Log at the step reached at the end of this epoch.
    global_step = (epoch + 1) * n_batch
    sw.add_scalar("ft_loss/valid", epoch_loss_meter.avg, global_step)
    sw.add_scalar("ft_f1/valid", epoch_f1_meter.avg, global_step)
    print(opt.model_folder)
    print(
        f"Epoch {epoch}, loss {epoch_loss_meter.avg:.3f}, f1 {epoch_f1_meter.avg:.3f}"
    )
    return epoch_loss_meter.avg, epoch_f1_meter.avg
Ejemplo n.º 11
0
def save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y, state_dict, config,
                    experiment_name):
    """Checkpoint the models and write sample and interpolation sheets.

    Args:
        G: Generator.
        D: Discriminator.
        G_ema: EMA copy of the generator; saved only when ``config['ema']``
            is set, sampled from only when ``config['use_ema']`` is also set.
        z_, y_: Latent / label samplers forwarded to the ``utils`` helpers.
        fixed_z, fixed_y: Fixed latents and labels for the fixed sample sheet.
        state_dict: Training state. ``'itr'`` and ``'save_num'`` are read;
            ``'save_num'`` is advanced modulo ``config['num_save_copies']``.
        config: Configuration dict.
        experiment_name: Sub-directory name under the weights and samples
            roots.
    """
    ema_to_save = G_ema if config['ema'] else None
    utils.save_weights(G, D, state_dict, config['weights_root'],
                       experiment_name, None, ema_to_save)
    # Save an additional rotating copy to mitigate accidental corruption if
    # the process is killed during a save.
    if config['num_save_copies'] > 0:
        utils.save_weights(G, D, state_dict, config['weights_root'],
                           experiment_name, 'copy%d' % state_dict['save_num'],
                           ema_to_save)
        state_dict['save_num'] = (state_dict['save_num'] +
                                  1) % config['num_save_copies']

    # Use EMA G for samples or non-EMA?
    which_G = G_ema if config['ema'] and config['use_ema'] else G

    # Accumulate standing statistics on the generator that is sampled from.
    if config['accumulate_stats']:
        print("accumulate_stats")
        utils.accumulate_standing_stats(
            which_G, z_, y_, config['n_classes'],
            config['num_standing_accumulations'])

    # Save a random sample sheet with fixed z and y
    with torch.no_grad():
        print("Save a random sample sheet with fixed z and y  ")
        if config['parallel']:
            fixed_Gz = nn.parallel.data_parallel(
                which_G, (fixed_z, which_G.shared(fixed_y)))
        else:
            fixed_Gz = which_G(fixed_z, which_G.shared(fixed_y))

    # makedirs(exist_ok=True) avoids the check-then-create race and also
    # creates a missing samples root.
    sample_dir = '%s/%s' % (config['samples_root'], experiment_name)
    os.makedirs(sample_dir, exist_ok=True)
    image_filename = '%s/%s/fixed_samples%d.jpg' % (
        config['samples_root'], experiment_name, state_dict['itr'])
    torchvision.utils.save_image(fixed_Gz,
                                 image_filename,
                                 nrow=int(fixed_Gz.shape[0]**0.5),
                                 normalize=True)

    # For now, every time we save, also save sample sheets
    utils.sample_sheet(
        which_G,
        classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
        num_classes=config['n_classes'],
        samples_per_class=10,
        parallel=config['parallel'],
        samples_root=config['samples_root'],
        experiment_name=experiment_name,
        folder_number=state_dict['itr'],
        z_=z_)

    # Also save interp sheets: neither fixed, y fixed, z fixed.
    for fix_z, fix_y in zip([False, False, True], [False, True, False]):
        utils.interp_sheet(which_G,
                           num_per_sheet=16,
                           num_midpoints=8,
                           num_classes=config['n_classes'],
                           parallel=config['parallel'],
                           samples_root=config['samples_root'],
                           experiment_name=experiment_name,
                           folder_number=state_dict['itr'],
                           sheet_number=0,
                           fix_z=fix_z,
                           fix_y=fix_y,
                           device='cuda')
Ejemplo n.º 12
0
def run(config):
    """Load a trained generator and produce the outputs ``config`` selects.

    Depending on the flags in ``config`` this saves an NPZ of samples
    (``sample_npz``), conditional sample sheets (``sample_sheets``),
    interpolation sheets (``sample_interps``), a random sample sheet
    (``sample_random``), prints Inception metrics
    (``sample_inception_metrics``) and sweeps the latent variance
    (``sample_trunc_curves``, formatted as ``"start_step_end"``).

    Args:
        config: Configuration dict. It is updated in place (resolution,
            class count, activations, ``skip_init``, ``no_optim``, and
            optionally entries restored from a saved state dict).
    """
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {
        'itr': 0,
        'epoch': 0,
        'save_num': 0,
        'save_best_num': 0,
        'best_IS': 0,
        'best_FID': 999999,
        'config': config
    }

    # Optionally, get the configuration from the state dict. This allows for
    # recovery of the config provided only a state dict and experiment name,
    # and can be convenient for writing less verbose sample shell scripts.
    if config['config_from_name']:
        utils.load_weights(None,
                           None,
                           state_dict,
                           config['weights_root'],
                           config['experiment_name'],
                           config['load_weights'],
                           None,
                           strict=False,
                           load_optim=False)
        # Ignore items which we might want to overwrite from the command line
        for item in state_dict['config']:
            if item not in [
                    'z_var', 'base_root', 'batch_size', 'G_batch_size',
                    'use_ema', 'G_eval_mode'
            ]:
                config[item] = state_dict['config'][item]

    # update config (see train.py for explanation)
    config['resolution'] = utils.imsize_dict[config['dataset']]
    config['n_classes'] = utils.nclass_dict[config['dataset']]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    config['skip_init'] = True
    config['no_optim'] = True
    # NOTE(review): the device is hard-coded; there is no CPU fallback here.
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])

    # Import the model--this line allows us to dynamically select different files.
    model = __import__(config['model'])
    experiment_name = (config['experiment_name'] if config['experiment_name']
                       else utils.name_from_config(config))
    print('Experiment name is %s' % experiment_name)

    G = model.Generator(**config)
    utils.count_parameters(G)

    # Load weights
    print('Loading weights...')
    # Here is where we deal with the ema--load ema weights or load normal weights.
    # G is passed in the plain-generator slot when use_ema is off, and in the
    # EMA slot when both ema and use_ema are on.
    # NOTE(review): with use_ema set but ema unset, both slots are None, so
    # no generator weights are loaded by this call.
    utils.load_weights(G if not (config['use_ema']) else None,
                       None,
                       state_dict,
                       config['weights_root'],
                       experiment_name,
                       config['load_weights'],
                       G if config['ema'] and config['use_ema'] else None,
                       strict=False,
                       load_optim=False)
    # Update batch size setting used for G
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    z_, y_ = utils.prepare_z_y(G_batch_size,
                               G.dim_z,
                               config['n_classes'],
                               device=device,
                               fp16=config['G_fp16'],
                               z_var=config['z_var'])

    if config['G_eval_mode']:
        print('Putting G in eval mode..')
        G.eval()
    else:
        print('G is in %s mode...' % ('training' if G.training else 'eval'))

    # Sample function: zero-argument callable bound to G and the z/y samplers
    sample = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)
    if config['accumulate_stats']:
        print('Accumulating standing stats across %d accumulations...' %
              config['num_standing_accumulations'])
        utils.accumulate_standing_stats(G, z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Sample a number of images and save them to an NPZ, for use with TF-Inception
    if config['sample_npz']:
        # Lists to hold images and labels for images
        x, y = [], []
        print('Sampling %d images and saving them to npz...' %
              config['sample_num_npz'])
        # Enough whole batches to cover sample_num_npz; the surplus is
        # trimmed after concatenation.
        for i in trange(
                int(np.ceil(config['sample_num_npz'] / float(G_batch_size)))):
            with torch.no_grad():
                images, labels = sample()
            # Map images from [-1, 1] to [0, 255] before casting to uint8
            # (assumes the generator output lies in [-1, 1] -- TODO confirm).
            x += [np.uint8(255 * (images.numpy() + 1) / 2.)]
            y += [labels.numpy()]
        x = np.concatenate(x, 0)[:config['sample_num_npz']]
        y = np.concatenate(y, 0)[:config['sample_num_npz']]
        print('Images shape: %s, Labels shape: %s' % (x.shape, y.shape))
        npz_filename = '%s/%s/samples.npz' % (config['samples_root'],
                                              experiment_name)
        print('Saving npz to %s...' % npz_filename)
        np.savez(npz_filename, **{'x': x, 'y': y})

    # Prepare sample sheets
    if config['sample_sheets']:
        print('Preparing conditional sample sheets...')
        utils.sample_sheet(
            G,
            classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
            num_classes=config['n_classes'],
            samples_per_class=10,
            parallel=config['parallel'],
            samples_root=config['samples_root'],
            experiment_name=experiment_name,
            folder_number=config['sample_sheet_folder_num'],
            z_=z_,
        )
    # Sample interp sheets (neither fixed, y fixed, z fixed)
    if config['sample_interps']:
        print('Preparing interp sheets...')
        for fix_z, fix_y in zip([False, False, True], [False, True, False]):
            utils.interp_sheet(G,
                               num_per_sheet=16,
                               num_midpoints=8,
                               num_classes=config['n_classes'],
                               parallel=config['parallel'],
                               samples_root=config['samples_root'],
                               experiment_name=experiment_name,
                               folder_number=config['sample_sheet_folder_num'],
                               sheet_number=0,
                               fix_z=fix_z,
                               fix_y=fix_y,
                               device='cuda')
    # Sample random sheet
    if config['sample_random']:
        print('Preparing random sample sheet...')
        # NOTE(review): unlike the NPZ branch above, this sample() call is not
        # wrapped in torch.no_grad().
        images, labels = sample()
        torchvision.utils.save_image(images.astype("float32"),
                                     '%s/%s/random_samples.jpg' %
                                     (config['samples_root'], experiment_name),
                                     nrow=int(G_batch_size**0.5),
                                     normalize=True)

    # Get Inception Score and FID
    get_inception_metrics = inception_utils.prepare_inception_metrics(
        config['dataset'], config['parallel'], config['no_fid'])

    # Prepare a simple function get metrics that we use for trunc curves
    def get_metrics():
        # Builds the same partial as the outer `sample`; it reads the current
        # z_ (whose variance the truncation sweep below changes).
        sample = functools.partial(utils.sample,
                                   G=G,
                                   z_=z_,
                                   y_=y_,
                                   config=config)
        IS_mean, IS_std, FID = get_inception_metrics(
            sample,
            config['num_inception_images'],
            num_splits=10,
            prints=False)
        # Prepare output string
        outstring = 'Using %s weights ' % ('ema'
                                           if config['use_ema'] else 'non-ema')
        outstring += 'in %s mode, ' % ('eval' if config['G_eval_mode'] else
                                       'training')
        outstring += 'with noise variance %3.3f, ' % z_.var
        outstring += 'over %d images, ' % config['num_inception_images']
        if config['accumulate_stats'] or not config['G_eval_mode']:
            outstring += 'with batch size %d, ' % G_batch_size
        if config['accumulate_stats']:
            outstring += 'using %d standing stat accumulations, ' % config[
                'num_standing_accumulations']
        outstring += 'Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' % (
            state_dict['itr'], IS_mean, IS_std, FID)
        print(outstring)

    if config['sample_inception_metrics']:
        print('Calculating Inception metrics...')
        get_metrics()

    # Sample truncation curve stuff. This is basically the same as the inception metrics code
    if config['sample_trunc_curves']:
        # The option is a string of the form "start_step_end".
        start, step, end = [
            float(item) for item in config['sample_trunc_curves'].split('_')
        ]
        print(
            'Getting truncation values for variance in range (%3.3f:%3.3f:%3.3f)...'
            % (start, step, end))
        # end + step makes the sweep include `end`.
        # NOTE(review): with a float step, rounding in np.arange may yield one
        # value past `end` -- confirm this is acceptable.
        for var in np.arange(start, end + step, step):
            z_.var = var
            # Optionally comment this out if you want to run with standing stats
            # accumulated at one z variance setting
            if config['accumulate_stats']:
                utils.accumulate_standing_stats(
                    G, z_, y_, config['n_classes'],
                    config['num_standing_accumulations'])
            get_metrics()
Ejemplo n.º 13
0
def main(args):
    dgl.random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.gpu >= 0:
        torch.cuda.manual_seed(args.seed)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location="cpu")
            pretrain_args = checkpoint["opt"]
            pretrain_args.fold_idx = args.fold_idx
            pretrain_args.gpu = args.gpu
            pretrain_args.finetune = args.finetune
            pretrain_args.resume = args.resume
            pretrain_args.cv = args.cv
            pretrain_args.dataset = args.dataset
            pretrain_args.epochs = args.epochs
            pretrain_args.num_workers = args.num_workers
            if args.dataset in GRAPH_CLASSIFICATION_DSETS:
                # HACK for speeding up finetuning on graph classification tasks
                pretrain_args.num_workers = 0
            pretrain_args.batch_size = args.batch_size
            args = pretrain_args
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    args = option_update(args)
    print(args)
    if args.gpu >= 0:
        assert args.gpu is not None and torch.cuda.is_available()
        print("Use GPU: {} for training".format(args.gpu))
    assert args.positional_embedding_size % 2 == 0
    print("setting random seeds")

    mem = psutil.virtual_memory()
    print("before construct dataset", mem.used / 1024**3)
    if args.finetune:
        if args.dataset in GRAPH_CLASSIFICATION_DSETS:
            dataset = GraphClassificationDatasetLabeled(
                dataset=args.dataset,
                rw_hops=args.rw_hops,
                subgraph_size=args.subgraph_size,
                restart_prob=args.restart_prob,
                positional_embedding_size=args.positional_embedding_size,
            )
            labels = dataset.dataset.data.y.tolist()
        else:
            dataset = NodeClassificationDatasetLabeled(
                dataset=args.dataset,
                rw_hops=args.rw_hops,
                subgraph_size=args.subgraph_size,
                restart_prob=args.restart_prob,
                positional_embedding_size=args.positional_embedding_size,
            )
            labels = dataset.data.y.argmax(dim=1).tolist()

        skf = StratifiedKFold(n_splits=10,
                              shuffle=True,
                              random_state=args.seed)
        idx_list = []
        for idx in skf.split(np.zeros(len(labels)), labels):
            idx_list.append(idx)
        assert (0 <= args.fold_idx
                and args.fold_idx < 10), "fold_idx must be from 0 to 9."
        train_idx, test_idx = idx_list[args.fold_idx]
        train_dataset = torch.utils.data.Subset(dataset, train_idx)
        valid_dataset = torch.utils.data.Subset(dataset, test_idx)

    elif args.dataset == "dgl":
        train_dataset = LoadBalanceGraphDataset(
            rw_hops=args.rw_hops,
            restart_prob=args.restart_prob,
            positional_embedding_size=args.positional_embedding_size,
            num_workers=args.num_workers,
            num_samples=args.num_samples,
            dgl_graphs_file="./data/small.bin",
            num_copies=args.num_copies,
        )
    else:
        if args.dataset in GRAPH_CLASSIFICATION_DSETS:
            train_dataset = GraphClassificationDataset(
                dataset=args.dataset,
                rw_hops=args.rw_hops,
                subgraph_size=args.subgraph_size,
                restart_prob=args.restart_prob,
                positional_embedding_size=args.positional_embedding_size,
            )
        else:
            train_dataset = NodeClassificationDataset(
                dataset=args.dataset,
                rw_hops=args.rw_hops,
                subgraph_size=args.subgraph_size,
                restart_prob=args.restart_prob,
                positional_embedding_size=args.positional_embedding_size,
            )

    mem = psutil.virtual_memory()
    print("before construct dataloader", mem.used / 1024**3)
    train_loader = torch.utils.data.graph.Dataloader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        collate_fn=labeled_batcher() if args.finetune else batcher(),
        shuffle=True if args.finetune else False,
        num_workers=args.num_workers,
        worker_init_fn=None
        if args.finetune or args.dataset != "dgl" else worker_init_fn,
    )
    if args.finetune:
        valid_loader = torch.utils.data.DataLoader(
            dataset=valid_dataset,
            batch_size=args.batch_size,
            collate_fn=labeled_batcher(),
            num_workers=args.num_workers,
        )
    mem = psutil.virtual_memory()
    print("before training", mem.used / 1024**3)

    # create model and optimizer
    # n_data = train_dataset.total
    n_data = None
    import gcc.models.graph_encoder
    gcc.models.graph_encoder.final_dropout = 0  ##disable dropout
    model, model_ema = [
        GraphEncoder(
            positional_embedding_size=args.positional_embedding_size,
            max_node_freq=args.max_node_freq,
            max_edge_freq=args.max_edge_freq,
            max_degree=args.max_degree,
            freq_embedding_size=args.freq_embedding_size,
            degree_embedding_size=args.degree_embedding_size,
            output_dim=args.hidden_size,
            node_hidden_dim=args.hidden_size,
            edge_hidden_dim=args.hidden_size,
            num_layers=args.num_layer,
            num_step_set2set=args.set2set_iter,
            num_layer_set2set=args.set2set_lstm_layer,
            norm=args.norm,
            gnn_model=args.model,
            degree_input=True,
        ) for _ in range(2)
    ]

    # copy weights from `model' to `model_ema'
    if args.moco:
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    contrast = MemoryMoCo(args.hidden_size,
                          n_data,
                          args.nce_k,
                          args.nce_t,
                          use_softmax=True)
    if args.gpu >= 0:
        contrast = contrast.cuda(args.gpu)

    if args.finetune:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = NCESoftmaxLoss() if args.moco else NCESoftmaxLossNS()
        if args.gpu >= 0:
            criterion = criterion.cuda(args.gpu)
    if args.gpu >= 0:
        model = model.cuda(args.gpu)
        model_ema = model_ema.cuda(args.gpu)

    if args.finetune:
        output_layer = nn.Linear(in_features=args.hidden_size,
                                 out_features=dataset.num_classes)
        if args.gpu >= 0:
            output_layer = output_layer.cuda(args.gpu)
        output_layer_optimizer = torch.optim.Adam(
            output_layer.parameters(),
            lr=args.learning_rate,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay,
        )

        def clear_bn(m):
            classname = m.__class__.__name__
            if classname.find("BatchNorm") != -1:
                m.reset_running_stats()

        model.apply(clear_bn)

    if args.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "adam":
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.learning_rate,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "adagrad":
        optimizer = torch.optim.Adagrad(
            model.parameters(),
            lr=args.learning_rate,
            lr_decay=args.lr_decay_rate,
            weight_decay=args.weight_decay,
        )
    else:
        raise NotImplementedError

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if True:
        # print("=> loading checkpoint '{}'".format(args.resume))
        # checkpoint = torch.load(args.resume, map_location="cpu")
        import torch as th
        checkpoint = th.load("torch_models/ckpt_epoch_100.pth",
                             map_location=th.device('cpu'))
        torch_input_output_grad = th.load(
            "torch_models/torch_input_output_grad.pt",
            map_location=th.device('cpu'))
        from paddorch.convert_pretrain_model import load_pytorch_pretrain_model
        print("loading.............. model")
        paddle_state_dict = load_pytorch_pretrain_model(
            model, checkpoint["model"])
        model.load_state_dict(paddle_state_dict)
        print("loading.............. contrast")
        paddle_state_dict2 = load_pytorch_pretrain_model(
            contrast, checkpoint["contrast"])
        contrast.load_state_dict(paddle_state_dict2)
        print("loading.............. model_ema")
        paddle_state_dict3 = load_pytorch_pretrain_model(
            model_ema, checkpoint["model_ema"])
        if args.moco:
            model_ema.load_state_dict(paddle_state_dict3)

        print("=> loaded successfully '{}' (epoch {})".format(
            args.resume, checkpoint["epoch"]))
        del checkpoint
        if args.gpu >= 0:
            torch.cuda.empty_cache()
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.learning_rate * 0.1,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay,
        )
        for _ in range(1):
            graph_q, graph_k = train_dataset[0]
            graph_q2, graph_k2 = train_dataset[1]
            graph_q, graph_k = dgl.batch([graph_q, graph_q2
                                          ]), dgl.batch([graph_k, graph_k2])

            input_output_grad = []
            input_output_grad.append([graph_q, graph_k])
            model.train()
            model_ema.eval()

            feat_q = model(graph_q)
            with torch.no_grad():
                feat_k = model_ema(graph_k)

            out = contrast(feat_q, feat_k)
            loss = criterion(out)
            optimizer.zero_grad()
            loss.backward()
            input_output_grad.append([feat_q, out, loss])
            print("loss:", loss.numpy())
            optimizer.step()
            moment_update(model, model_ema, args.alpha)
        print(
            "max diff feat_q:",
            np.max(
                np.abs(torch_input_output_grad[1][0].detach().numpy() -
                       feat_q.numpy())))
        print(
            "max diff out:",
            np.max(
                np.abs(torch_input_output_grad[1][1].detach().numpy() -
                       out.numpy())))
        print(
            "max diff loss:",
            np.max(
                np.abs(torch_input_output_grad[1][2].detach().numpy() -
                       loss.numpy())))

        name2grad = dict()
        for name, p in dict(model.named_parameters()).items():
            if p.grad is not None:
                name2grad[name] = p.grad
                torch_grad = torch_input_output_grad[2][name].numpy()

                if "linear" in name and "weight" in name:
                    torch_grad = torch_grad.T
                max_grad_diff = np.max(np.abs(p.grad - torch_grad))
                print("max grad diff:", name, max_grad_diff)
        input_output_grad.append(name2grad)
Ejemplo n.º 14
0
def train_moco(epoch, train_loader, model, model_ema, contrast, criterion,
               optimizer, sw, opt):
    """
    Train for one epoch over ``train_loader`` and return the epoch loss average.

    When ``opt.moco`` is set, queries are encoded by ``model`` and keys by
    ``model_ema`` (inside ``torch.no_grad()``), ``contrast`` turns the two
    feature batches into logits, and ``moment_update`` is called on
    ``model_ema`` after every optimizer step. Otherwise both graphs are
    encoded by ``model`` and the logits are ``feat_k @ feat_q.T`` divided by
    ``opt.nce_t``.

    Args:
        epoch: index of the current epoch; used for the global step and logs.
        train_loader: iterable yielding ``(graph_q, graph_k)`` batches; its
            ``dataset.total`` is read to derive the number of batches.
        model: encoder that is optimised.
        model_ema: key encoder used for the forward pass in MoCo mode.
        contrast: callable mapping ``(feat_q, feat_k)`` to logits (MoCo mode).
        criterion: loss applied to the logits.
        optimizer: object providing ``zero_grad``, ``set_lr``, ``step`` and
            ``param_groups``.
        sw: summary writer providing ``add_scalar``.
        opt: options object; reads ``batch_size``, ``moco``, ``nce_t``,
            ``hidden_size``, ``learning_rate``, ``epochs``, ``alpha``,
            ``print_freq`` and ``tb_freq``.

    Returns:
        ``avg`` of the loss meter that is never reset during the epoch.
    """
    n_batch = train_loader.dataset.total // opt.batch_size

    # Hard-coded debugging switch: when True, a fixed two-sample batch is
    # replayed on every iteration and backward / momentum updates are skipped.
    no_update_debug = False

    def _bn_to_train(module):
        # Only layers whose class name contains "BatchNorm" are switched.
        if "BatchNorm" in module.__class__.__name__:
            module.train()

    if no_update_debug:
        model.eval()
        contrast.eval()
        model_ema.eval()
    else:
        model.train()
        model_ema.eval()
        # model_ema stays in eval mode except for its BatchNorm layers.
        model_ema.apply(_bn_to_train)

    batch_timer = AverageMeter()
    data_timer = AverageMeter()
    loss_meter = AverageMeter()
    epoch_loss_meter = AverageMeter()
    prob_meter = AverageMeter()
    graph_size = AverageMeter()
    gnorm_meter = AverageMeter()
    max_num_nodes, max_num_edges = 0, 0

    tick = time.time()
    if no_update_debug:
        first_q, first_k = train_loader.dataset[0]
        second_q, second_k = train_loader.dataset[1]
        graph_q = dgl.batch([first_q, second_q])
        graph_k = dgl.batch([first_k, second_k])

    for idx, batch in enumerate(train_loader):
        data_timer.update(time.time() - tick)
        if not no_update_debug:
            graph_q, graph_k = batch

        bsz = graph_q.batch_size

        # ===================forward=====================
        # The query branch is identical in both modes.
        feat_q = model(graph_q)
        if opt.moco:
            with torch.no_grad():
                feat_k = model_ema(graph_k)
            out = contrast(feat_q, feat_k)
            prob = out[:, 0].mean()
        else:
            # Negative sampling: the diagonal holds the matching pairs.
            feat_k = model(graph_k)
            out = torch.matmul(feat_k, feat_q.t()) / opt.nce_t
            prob = out[range(bsz), range(bsz)].mean()

        assert feat_q.shape == (bsz, opt.hidden_size)

        # ===================backward=====================
        optimizer.zero_grad()
        loss = criterion(out)
        if not no_update_debug:
            loss.backward()

        # NOTE(review): the limit passed here is the literal 0; what that
        # means is decided by clip_grad_norm, which is defined elsewhere.
        grad_norm = clip_grad_norm(model.parameters(), 0)

        # Learning-rate schedule driven by the fraction of training completed.
        global_step = epoch * n_batch + idx
        progress = global_step / (opt.epochs * n_batch)
        lr_this_step = opt.learning_rate * warmup_linear(progress, 0.1)
        if lr_this_step is not None:
            optimizer.set_lr(lr_this_step)
        optimizer.step()

        loss_value = loss.item()
        if no_update_debug:
            print(loss_value)

        # ===================meters=====================
        nodes_q = graph_q.number_of_nodes()
        nodes_k = graph_k.number_of_nodes()
        loss_meter.update(loss_value, bsz)
        epoch_loss_meter.update(loss_value, bsz)
        prob_meter.update(prob.item(), bsz)
        graph_size.update((nodes_q + nodes_k) / 2.0 / bsz, 2 * bsz)
        gnorm_meter.update(grad_norm, 1)
        max_num_nodes = max(max_num_nodes, nodes_q)
        max_num_edges = max(max_num_edges, graph_q.number_of_edges())

        if opt.moco and not no_update_debug:
            moment_update(model, model_ema, opt.alpha)

        batch_timer.update(time.time() - tick)
        tick = time.time()

        step_no = idx + 1

        # console progress line
        if step_no % opt.print_freq == 0:
            mem = psutil.virtual_memory()
            template = ("Train: [{0}][{1}/{2}]\t"
                        "BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                        "DT {data_time.val:.3f} ({data_time.avg:.3f})\t"
                        "loss {loss.val:.3f} ({loss.avg:.3f})\t"
                        "prob {prob.val:.3f} ({prob.avg:.3f})\t"
                        "GS {graph_size.val:.3f} ({graph_size.avg:.3f})\t"
                        "mem {mem:.3f}")
            print(
                template.format(
                    epoch,
                    step_no,
                    n_batch,
                    batch_time=batch_timer,
                    data_time=data_timer,
                    loss=loss_meter,
                    prob=prob_meter,
                    graph_size=graph_size,
                    mem=mem.used / 1024**3,
                ))

        # tensorboard logger; windowed meters restart after every flush
        if step_no % opt.tb_freq == 0:
            sw.add_scalar("moco_loss", loss_meter.avg, global_step)
            sw.add_scalar("moco_prob", prob_meter.avg, global_step)
            sw.add_scalar("graph_size", graph_size.avg, global_step)
            sw.add_scalar("graph_size/max", max_num_nodes, global_step)
            sw.add_scalar("graph_size/max_edges", max_num_edges, global_step)
            sw.add_scalar("gnorm", gnorm_meter.avg, global_step)
            sw.add_scalar("learning_rate", optimizer.param_groups[0]["lr"],
                          global_step)
            for meter in (loss_meter, prob_meter, graph_size, gnorm_meter):
                meter.reset()
            max_num_nodes, max_num_edges = 0, 0

    return epoch_loss_meter.avg