def _check(self, vec, base):
    # Orthogonalize vec against every vector in base; if the remainder is
    # numerically zero, vec lies in the span of base and is rejected.
    for b in base:
        vec = utils.gram_schmidt(vec, b)
        norm = np.sqrt(np.vdot(vec, vec))
        if norm < self.prec:
            return None
    return vec
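For reference, a minimal sketch of what the two-argument utils.gram_schmidt used throughout these examples appears to do, inferred from the unit tests in Examples #6-#8: subtract from vec its projection onto b, leaving vec unchanged when b is numerically zero. The prec keyword is an assumption based on the call in Example #3.

import numpy as np

def gram_schmidt(vec, b, prec=1e-6):
    """One Gram-Schmidt step: remove from vec its component along b.

    If b is numerically zero there is nothing to project out, so vec is
    returned unchanged (the behaviour pinned down by Example #6).
    """
    if np.linalg.norm(b) < prec:
        return vec
    return vec - (np.vdot(b, vec) / np.vdot(b, b)) * b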
Example #2
def train(args):
    # Dataset
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    train_dataset, valid_dataset = create_train_and_valid_dataset(
        dataset=args.dataset,
        dirpath=args.data_dir,
        tokenizer=tokenizer,
        num_train_data=args.num_train_data,
        augmentation=args.data_augment,
    )

    # Loader
    collate_fn = CollateFn(tokenizer, args.max_length)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              collate_fn=collate_fn)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=256,
                              shuffle=False,
                              collate_fn=collate_fn)
    plot_loader = DataLoader(valid_dataset,
                             batch_size=4,
                             shuffle=False,
                             collate_fn=collate_fn)

    # Device
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
        logging.info("CUDA DEVICE: %d" % torch.cuda.current_device())
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model
    model = create_model(
        augment=args.mix_strategy,
        mixup_layer=args.m_layer,
        d_layer=args.d_layer,
        n_class=train_dataset.n_class,
        n_layer=12,
        drop_prob=args.drop_prob,
    )
    model.load()  # Load BERT pretrained weight
    model.to(device)

    # Criterion (nonlinearmix computes its loss inside its own loss function,
    # so no external criterion is needed)
    if args.mix_strategy == "nonlinearmix":
        criterion = None
    else:
        criterion = nn.CrossEntropyLoss()

    # Optimizer
    if args.mix_strategy == "none":
        optimizers = [
            optim.Adam(model.embedding_model.parameters(), lr=args.lr),
            optim.Adam(model.classifier.parameters(), lr=1e-3),
        ]
    elif args.mix_strategy == "tmix":
        optimizers = [
            optim.Adam(model.mix_model.embedding_model.parameters(),
                       lr=args.lr),
            optim.Adam(model.classifier.parameters(), lr=1e-3),
        ]
    elif args.mix_strategy == "nonlinearmix":
        optimizers = [
            optim.Adam(model.mix_model.embedding_model.parameters(),
                       lr=args.lr),
            optim.Adam(model.mix_model.policy_mapping_f.parameters(),
                       lr=args.lr),
            optim.Adam(model.classifier.parameters(), lr=1e-3),
            optim.Adam([model.label_matrix], lr=1e-3),
        ]
    elif args.mix_strategy == "mixuptransformer":
        optimizers = [
            optim.Adam(model.embedding_model.parameters(), lr=args.lr),
            optim.Adam(model.classifier.parameters(), lr=1e-3),
        ]
    elif args.mix_strategy == "oommix":
        optimizers = [
            optim.Adam(model.mix_model.embedding_model.parameters(),
                       lr=args.lr),
            optim.Adam(model.mix_model.embedding_generator.parameters(),
                       lr=args.lr),
            optim.Adam(model.mix_model.manifold_discriminator.parameters(),
                       lr=args.lr),
            optim.Adam(model.classifier.parameters(), lr=1e-3),
        ]

    # Scheduler: linear learning-rate warmup over the first 1000 steps
    schedulers = [
        optim.lr_scheduler.LambdaLR(optimizer, lambda x: min(x / 1000, 1))
        for optimizer in optimizers
    ]

    # Writer
    writers = {
        "tensorboard": SummaryWriter(args.out_dir),
        "gamma": open(os.path.join(args.out_dir, "gamma.csv"), "w"),
        "scalar": open(os.path.join(args.out_dir, "scalar.csv"), "w"),
    }

    step, best_acc, patience = 0, 0, 0
    model.train()
    for optimizer in optimizers:
        optimizer.zero_grad()
    for epoch in range(1, args.epoch + 1):
        for batch in train_loader:
            step += 1
            input_ids = batch["inputs"]["input_ids"].to(device)
            attention_mask = batch["inputs"]["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            if args.mix_strategy == "none":
                loss = calculate_normal_loss(model, input_ids, attention_mask,
                                             labels, epoch, step)
            elif args.mix_strategy == "tmix":
                loss = calculate_tmix_loss(model, input_ids, attention_mask,
                                           labels, args.alpha, epoch, step)
            elif args.mix_strategy == "nonlinearmix":
                loss = calculate_nonlinearmix_loss(model, input_ids,
                                                   attention_mask, labels,
                                                   args.alpha, epoch, step)
            elif args.mix_strategy == "mixuptransformer":
                loss = calculate_mixuptransformer_loss(model, criterion,
                                                       input_ids,
                                                       attention_mask, labels,
                                                       epoch, step)
            elif args.mix_strategy == "oommix":
                loss, mani_loss = calculate_oommix_loss(
                    model,
                    criterion,
                    input_ids,
                    attention_mask,
                    labels,
                    epoch,
                    step,
                    writers["gamma"],
                )
                # Order is important! First propagate the manifold loss to the
                # discriminator and generator...
                (args.coeff_intr * mani_loss).backward(retain_graph=True)
                # ...then clear its gradient from the embedding model, so the
                # manifold loss only updates generator and discriminator...
                optimizers[0].zero_grad()
                # ...before propagating the task loss to model and generator
                ((1 - args.coeff_intr) * loss).backward()
            if step % 5 == 0:
                writers["scalar"].write(
                    "%d,train loss,%.4f\n" %
                    (int(datetime.now().timestamp()), loss.item()))
                if args.mix_strategy == "oommix":
                    writers["scalar"].write(
                        "%d,manifold classification loss,%.4f\n" %
                        (int(datetime.now().timestamp()), mani_loss.item()))
            for optimizer in optimizers:
                optimizer.step()
                optimizer.zero_grad()
            for scheduler in schedulers:
                scheduler.step()
            if args.mix_strategy == "nonlinearmix":
                # Apply gram schmidt
                with torch.no_grad():
                    gs = (torch.from_numpy(
                        gram_schmidt(
                            model.label_matrix.t().cpu().numpy())).t().to(
                                device))
                    model.label_matrix.copy_(gs)

            if step % args.eval_every == 0:
                acc = evaluate_model(model, valid_loader, device)
                writers["scalar"].write("%d,valid acc,%.4f\n" %
                                        (int(datetime.now().timestamp()), acc))
                if best_acc < acc:
                    best_acc = acc
                    patience = 0
                    torch.save(model.state_dict(),
                               os.path.join(args.out_dir, "model.pth"))
                else:
                    patience += 1
                logging.info("Accuracy: %.4f, Best accuracy: %.4f" %
                             (acc, best_acc))
                if patience == args.patience:
                    break
        if patience == args.patience:
            break
    for w in writers.values():
        w.close()
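Both the nonlinearmix branch above and Example #4 below pass a whole matrix to gram_schmidt and expect an orthonormalized matrix of the same shape back. A plausible minimal implementation of that variant, assuming it orthonormalizes the rows in order and that the rows are linearly independent (otherwise the shape could not be preserved for label_matrix.copy_):

import numpy as np

def gram_schmidt(vectors):
    """Orthonormalize the rows of `vectors` with classical Gram-Schmidt.

    Assumes linearly independent rows, so the result has the same shape
    as the input.
    """
    basis = np.array(vectors, dtype=float)
    for i in range(basis.shape[0]):
        for j in range(i):
            # subtract the projection onto the already orthonormal row j
            basis[i] -= np.dot(basis[j], basis[i]) * basis[j]
        basis[i] /= np.linalg.norm(basis[i])
    return basis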
Example #3
def calc_cg_new_(self, groups, p):
    self.cgnames = []
    self.cgind = []
    self.cg = []
    if groups is None:
        return
    g = groups[p]
    dim1 = self.gamma1.shape[1]
    dim2 = self.gamma2.shape[1]
    dim12 = dim1 * dim2
    coeff = np.zeros((len(self.indices),), dtype=complex)
    lind = []
    for indir, ir in enumerate(g.irreps):
        multi = 0
        lcoeffs = []
        dim = ir.dim
        mup = 0
        # loop over all column index combinations that conserve the COM momentum
        for mu1, mu2 in self.indices:
            # loop over the rows of the final irrep
            for mu in range(dim):
                if mu != mup:
                    continue
                coeff.fill(0.)
                # loop over all combinations of rows of the induced
                # representations
                for ind1, (mu1p, mu2p) in enumerate(self.indices):
                    for i in range(g.order):
                        tmp = ir.mx[i][mu, mup].conj()
                        look = g.lelements[i]
                        tmp *= self.gamma1[look, mu1p, mu1]
                        tmp *= self.gamma2[look, mu2p, mu2]
                        coeff[ind1] += tmp
                coeff *= float(dim) / g.order
                ncoeff = np.sqrt(np.vdot(coeff, coeff))
                # if the norm is 0, try the next combination of mu', mu1, mu2
                if ncoeff < self.prec:
                    continue
                coeff /= ncoeff
                # orthogonalize w.r.t. already found vectors of the same irrep
                for vec in lcoeffs:
                    coeff = utils.gram_schmidt(coeff, vec, prec=self.prec)
                    ncoeff = np.sqrt(np.vdot(coeff, coeff))
                    # if reduced to the zero vector, try the next combination
                    if ncoeff < self.prec:
                        break
                if ncoeff < self.prec:
                    continue
                # orthogonalize w.r.t. already found vectors of other irreps
                for lcg in self.cg:
                    for vec in lcg:
                        coeff = utils.gram_schmidt(coeff, vec, prec=self.prec)
                        ncoeff = np.sqrt(np.vdot(coeff, coeff))
                        # if reduced to the zero vector, try the next combination
                        if ncoeff < self.prec:
                            break
                    if ncoeff < self.prec:
                        break
                if ncoeff > self.prec:
                    lcoeffs.append(coeff.copy())
                    lind.append((mu, mup, mu1, mu2))
                    multi += 1
        if multi > 0:
            print("%s: %d times" % (ir.name, multi))
            self.cgnames.append((ir.name, multi, dim))
            self.cg.append(np.asarray(lcoeffs).copy())
            self.cgind.append(np.asarray(lind).copy())
    self.cg = np.asarray(self.cg)
    self.cgind = np.asarray(self.cgind)
Example #4
def visualize(args):

    # load model to visualize
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Creating Model...")
    num_classes = 10 if args.dataset == 'cifar10' else 100

    model = ResNet18(num_classes=num_classes).to(device)
    checkpoint = torch.load(args.checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # get template vector for given target classes
    assert len(args.target_classes) == 3
    assert all(0 <= c < num_classes for c in args.target_classes)

    template_vecs = []
    for c in args.target_classes:
        template_vecs.append(model.linear.weight[c].detach().cpu().numpy())

    # find orthonormal basis of plane using gram-schmidt
    basis = gram_schmidt(template_vecs)  # shape: (3, D)

    # To get penultimate representations of images for the target classes,
    # replace the model's last layer with an identity layer
    model.linear = nn.Identity()

    # load data
    if args.dataset == 'cifar10':
        train_dloader, test_dloader = load_cifar10()
    else:
        train_dloader, test_dloader = load_cifar100()

    if args.use_train_set:
        dloader = train_dloader
    else:
        dloader = test_dloader

    representations = []
    ys = []

    for x, y in dloader:
        idx_to_use = []
        for idx in range(len(y)):
            if y[idx] in args.target_classes:
                idx_to_use.append(idx)

        if len(idx_to_use) == 0:
            continue

        x = x[idx_to_use].to(device)
        y = y[idx_to_use].to(device)

        with torch.no_grad():
            representation = model(x).detach().cpu()

        for i in range(len(y)):
            representations.append(representation[i].numpy())
            ys.append(int(y[i].item()))

    X = np.stack(representations, axis=0)  # (N * 3, D)

    # visualize
    colors = ['blue', 'red', 'green']
    c = [colors[args.target_classes.index(y)] for y in ys]

    proj_X = X @ basis.T  # (N * 3, 3)

    # NOTE: I didn't fully understand how the authors got 2d visualization, so I just used PCA.
    proj_X_2d = PCA(n_components=2).fit_transform(proj_X)  # (N * 3, 2)
    plt.scatter(proj_X_2d[:, 0], proj_X_2d[:, 1], s=3, c=c)

    # plt.show()
    plt.savefig(args.visualization_save_path)
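A hypothetical sanity check (not in the original) that the basis returned by gram_schmidt is in fact orthonormal before projecting onto it:

# Rows of `basis` should be orthonormal, so the Gram matrix
# basis @ basis.T is approximately the 3x3 identity.
assert np.allclose(basis @ basis.T, np.eye(3), atol=1e-5)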
Example #5

    if labels.shape[0] > 1:
        # calculate cluster centres
        cluster_centres = np.zeros(shape=(n_classes, n_features))
        for j in range(n_classes):
            j_idx = labels == j
            features_j = features[j_idx, :]
            if len(features_j.shape) < 2:
                features_j = np.expand_dims(features_j, axis=0)

            phi_j = group_net.predict(features_j)
            cluster_centres[j, :] = np.mean(phi_j, axis=0)

        # orthonormalize cluster centres
        cluster_centres, dropped, fatal = gram_schmidt(cluster_centres)
        if dropped:
            drop_count.append(drop_count[-1] + 1)
        else:
            drop_count.append(drop_count[-1])

        if fatal:
            continue

        # inter cluster distance
        temp = 0
        count = 0
        for j in range(n_classes):
            for k in range(n_classes):
                if k > j:
                    temp += norm(cluster_centres[j, :] - cluster_centres[k, :])
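This fragment expects a gram_schmidt variant that also reports whether any cluster centre was dropped as linearly dependent and whether orthonormalization failed outright. A minimal sketch consistent with that calling convention; the exact semantics of dropped and fatal are assumptions inferred from how the flags are used above:

import numpy as np

def gram_schmidt(mat, prec=1e-6):
    """Row-wise Gram-Schmidt returning (basis, dropped, fatal).

    dropped: True if any row was (numerically) dependent and discarded.
    fatal:   True if so few rows survive that the caller cannot continue.
    """
    basis = []
    dropped = False
    for row in np.asarray(mat, dtype=float):
        w = row.copy()
        for b in basis:
            w -= np.dot(b, w) * b
        norm = np.linalg.norm(w)
        if norm < prec:
            dropped = True
            continue
        basis.append(w / norm)
    fatal = len(basis) < 2
    return np.asarray(basis), dropped, fatal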
Example #6
def test_zero_projection_vector(self):
    vec0 = np.zeros((3,))
    vec1 = np.ones((3,))
    res = ut.gram_schmidt(vec1, vec0)
    # assertEqual is ambiguous for numpy arrays; compare element-wise
    np.testing.assert_array_equal(res, vec1)
Example #7
def test_perpendicular_vectors(self):
    vec1 = np.asarray([1., 0., 0.])
    vec2 = np.asarray([0., 1., 0.])
    res = ut.gram_schmidt(vec1, vec2)
    np.testing.assert_array_equal(res, vec1)
Example #8
def test_equal_vectors(self):
    vec = np.ones((3,)) / np.sqrt(3.)
    vec0 = np.zeros((3,))
    res = ut.gram_schmidt(vec, vec)
    # allow for floating-point round-off in the projection
    np.testing.assert_array_almost_equal(res, vec0)
Example #9
def main():

    parser = argparse.ArgumentParser(description='PCA with PyTorch')
    parser.add_argument(
        '--method',
        default='l2rmsg',
        help="one of ['l1rmsg', 'l2rmsg', 'l12rmsg', 'msg', "
             "'incremental', 'original', 'sgd']")
    parser.add_argument('--subspace_ratio',
                        type=float,
                        default=0.5,
                        help='k/d ratio')
    parser.add_argument('--beta',
                        type=float,
                        default=0.5,
                        help='regularization const for l1')
    parser.add_argument('--lambda',
                        type=float,
                        default=1e-3,
                        help='regularization const for l2')
    parser.add_argument('--eta', type=float, default=1, help='learning rate')
    parser.add_argument(
        '--eps',
        type=float,
        default=1e-6,
        help='threshold on norm for rank-1 update of non-trivial components')
    parser.add_argument('--nepochs',
                        type=int,
                        default=20,
                        help='no. of epochs')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='To train on GPU')
    parser.add_argument('--verbose',
                        action='store_true',
                        default=False,
                        help='if true then progress bar gets printed')
    parser.add_argument('--log_interval',
                        type=int,
                        default=1,
                        help='log interval in epochs')
    args = parser.parse_args()
    device = lambda tens: device_templ(tens, args.cuda)

    torch.manual_seed(7)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(7)

    print('-----------------------TRAINING {}--------------------'.format(
        args.method.upper()))

    X_train, X_val, X_test = get_syn_data(device=device)
    k = int(args.subspace_ratio * X_train.size(1))
    d = X_train.size(1)
    epochs_iter = range(args.nepochs)
    epochs_iter = tqdm.tqdm(epochs_iter) if args.verbose else epochs_iter

    for epoch in epochs_iter:
        iterator = DataLoader(TensorDataset(X_train), shuffle=True)
        iterator = tqdm.tqdm(iterator) if args.verbose else iterator
        for x in iterator:
            x = x[0].squeeze()
            method = args.method
            if method in ['l1rmsg', 'l2rmsg', 'l12rmsg', 'msg']:
                if epoch == 0:
                    U = device(torch.zeros(d, k).float())
                    S = device(torch.zeros(k).float())
                U, S = msg(U, S, k, x, args.eta, args.eps, args.beta)
            elif method == 'incremental':
                if epoch == 0:
                    U = device(torch.zeros(d, k).float())
                    S = device(torch.zeros(k).float())
                U, S = incremental_update(U, S, x, max_rank=None)
                # print(U,S)
                U, S = U[:, :k], S[:k]
            elif method == 'sgd':
                if epoch == 0:
                    U = gram_schmidt(
                        nn.init.uniform_(device(torch.zeros(d, k))))
                U = stochastic_power_update(U, x, args.eta)
                U = gram_schmidt(U)
            elif method == 'original':
                _, S, V = torch.svd(X_train)
                U = V[:, :k]
                break
        if method in ['l1rmsg', 'l2rmsg', 'l12rmsg', 'msg']:
            finalU = U[:, :k]
        else:
            # 'incremental', 'sgd', and 'original' already yield k components
            finalU = U
        if method == 'original':
            break
        if epoch % args.log_interval == 0:
            # print(epoch)
            print('Objective(higher is good): TRAIN {:.4f} VALIDATION {:.4f}'.
                  format(objective(X_train, finalU), objective(X_val, finalU)))
    method = args.method
    print('Objective(higher is good): TRAIN {:.4f} VALIDATION {:.4f}'.format(
        objective(X_train, finalU), objective(X_val, finalU)))
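The 'sgd' branch re-orthonormalizes the d-by-k matrix U after every stochastic power update. A compact stand-in using a reduced QR factorization, which plays the same role as Gram-Schmidt but is numerically more stable (my substitution; the repo's gram_schmidt may differ):

import torch

def gram_schmidt(U):
    """Return a matrix whose orthonormal columns span the columns of U."""
    Q, _ = torch.linalg.qr(U)  # reduced QR: Q has shape (d, k)
    return Q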