Example #1
def TEP_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain, args.n_clusters).float().to(device)        # intermediate values
    z_bars = torch.zeros(ntrain, args.n_clusters).float().to(device)        # temporal outputs
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(device)  # current outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, ((x, _), label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device) 
            _, feat = model(x)
            prob = feat2prob(feat, model.center)
            # F.kl_div's default reduction='mean' averages over all elements;
            # reduction='batchmean' matches the mathematical KL divergence.
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        exp_lr_scheduler.step()  # step once per epoch, after the optimizer updates (PyTorch >= 1.1 order)

        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args, epoch)
        z_epoch = probs.float().to(device)
        Z = alpha * Z + (1. - alpha) * z_epoch
        z_bars = Z * (1. / (1. - alpha ** (epoch + 1)))

        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(z_bars).float().to(device) 
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
Example #2
def train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(
            epoch, args.rampup_length)
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            _, feat_g = model(g_x.to(device))
            prob = feat2prob(feat, model.center)
            prob_g = feat2prob(feat_g, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            mse_loss = F.mse_loss(prob, prob_g)
            loss = loss + w * mse_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(
            epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)
        args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
Example #3
def PI_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)
    w = 0
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length) 
        for batch_idx, ((x, x_bar), label, idx) in enumerate(tqdm(train_loader)):
            x, x_bar = x.to(device), x_bar.to(device)
            _, feat = model(x)
            _, feat_bar = model(x_bar)
            prob = feat2prob(feat, model.center)
            prob_bar = feat2prob(feat_bar, model.center)
            sharp_loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            consistency_loss = F.mse_loss(prob, prob_bar)
            loss = sharp_loss + w * consistency_loss 
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        exp_lr_scheduler.step()  # step once per epoch, after the optimizer updates
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args, epoch)

        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(probs) 
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
Example #4
def Baseline_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer,
                                                milestones=args.milestones,
                                                gamma=args.gamma)
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, (x, label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device)
            feat = model(x)
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        exp_lr_scheduler.step()  # step once per epoch, after the optimizer updates
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(
            epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args)
        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
Example #5
def TEP_train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain, args.n_clusters).float().to(device)        # intermediate values
    z_bars = torch.zeros(ntrain, args.n_clusters).float().to(device)       # temporal (bias-corrected) outputs
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(device)  # current outputs

    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))

        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)
        z_epoch = probs.float().to(device)
        Z = alpha * Z + (1. - alpha) * z_epoch
        z_bars = Z * (1. / (1. - alpha ** (epoch + 1)))

        if epoch % args.update_interval == 0:
            args.p_targets = target_distribution(z_bars).float().to(device) 
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
Example #6
def warmup_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.warmup_lr, momentum=args.momentum, weight_decay=args.weight_decay)
    for epoch in range(args.warmup_epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, ((x, _), label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device)
            _, feat = model(x)
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('Warmup_train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args, epoch)
    args.p_targets = target_distribution(probs) 
Example #7
def TE_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    w = 0
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain,
                    args.n_clusters).float().to(device)  # intermediate values
    z_ema = torch.zeros(ntrain,
                        args.n_clusters).float().to(device)  # temporal outputs
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(
        device)  # current outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(
            epoch, args.rampup_length)
        for batch_idx, ((x, _), label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device)
            feat = model(x)
            prob = feat2prob(feat, model.center)
            z_epoch[idx, :] = prob.detach()  # record predictions without growing the autograd graph
            prob_bar = z_ema[idx, :].detach()  # ensemble target; Variable(..., requires_grad=False) is deprecated
            sharp_loss = F.kl_div(prob.log(),
                                  args.p_targets[idx].float().to(device))
            consistency_loss = F.mse_loss(prob, prob_bar)
            loss = sharp_loss + w * consistency_loss
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        Z = alpha * Z + (1. - alpha) * z_epoch
        z_ema = Z * (1. / (1. - alpha**(epoch + 1)))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(
            epoch, loss_record.avg))
        acc, _, _, probs = test(model, eva_loader, args)

        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
Example #8
def Baseline_train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)

        if epoch % args.update_interval == 0:
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
Example #9
        args.model_dir = model_dir + '/' + 'vgg6_{}.pth'.format(alphabetStr)
        args.save_txt_path = args.exp_root + '{}/{}'.format(
            runner_name, args.save_txt_name)
        train_Dloader, eval_Dloader = omniglot_alphabet_func(
            alphabet=alphabetStr, background=False,
            root=args.dataset_root)(batch_size=args.batch_size,
                                    num_workers=args.num_workers)
        args.n_clusters = alphabets_k_mapping[key]
        model = VGG(n_layer='4+2', out_dim=args.n_clusters,
                    in_channels=1).to(device)
        model.load_state_dict(torch.load(args.pretrain_dir), strict=False)
        model.center = Parameter(torch.Tensor(args.n_clusters,
                                              args.n_clusters))
        init_centers, init_probs = init_prob_kmeans(init_feat_extractor,
                                                    eval_Dloader, args)
        args.p_targets = target_distribution(init_probs)
        model.center.data = torch.tensor(init_centers).float().to(device)
        warmup_train(model, alphabetStr, train_Dloader, eval_Dloader, args)
        train(model, alphabetStr, train_Dloader, eval_Dloader, args)
        acc[alphabetStr], nmi[alphabetStr], ari[alphabetStr], _ = test(
            model, eval_Dloader, args)
    print('ACC for all alphabets:', acc)
    print('NMI for all alphabets:', nmi)
    print('ARI for all alphabets:', ari)
    avg_acc = sum(acc.values()) / float(len(acc))
    avg_nmi = sum(nmi.values()) / float(len(nmi))
    avg_ari = sum(ari.values()) / float(len(ari))
    print('avg ACC {:.4f}, NMI {:.4f} ARI {:.4f}'.format(
        avg_acc, avg_nmi, avg_ari))

    if args.save_txt:
        with open(args.save_txt_path, 'a') as f: