import torch
import torch.nn.functional as F
from torch.nn import Parameter
from torch.optim import SGD, Adam, lr_scheduler
from tqdm import tqdm

# Project-local helpers assumed to be defined elsewhere in the repo:
# ramps, AverageMeter, feat2prob, target_distribution, init_prob_kmeans,
# test, VGG, omniglot_alphabet_func, alphabets_k_mapping, device


def TEP_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.lr,
                    momentum=args.momentum, weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=args.gamma)
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain, args.n_clusters).float().to(device)        # accumulated ensemble
    z_bars = torch.zeros(ntrain, args.n_clusters).float().to(device)   # bias-corrected targets
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(device)  # current-epoch outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, ((x, _), label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device)
            _, feat = model(x)
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        exp_lr_scheduler.step()  # step once per epoch, after the optimizer updates
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args, epoch)
        z_epoch = probs.float().to(device)
        # Temporal ensembling of epoch-level predictions, with bias correction
        Z = alpha * Z + (1. - alpha) * z_epoch
        z_bars = Z * (1. / (1. - alpha ** (epoch + 1)))
        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(z_bars).float().to(device)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
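# --- Hedged sketch, not the repo's actual code ---
# feat2prob and target_distribution are defined elsewhere in the repo. A
# minimal version consistent with their usage above, assuming DEC-style
# Student's t soft assignments and squared-frequency target sharpening
# (the actual implementations may differ):

def feat2prob(feat, center, alpha=1.0):
    # Student's t kernel (as in DEC): similarity of each embedded point
    # (feat: N x D) to each cluster center (center: K x D), row-normalized.
    dist = torch.sum((feat.unsqueeze(1) - center) ** 2, dim=2)
    q = (1.0 + dist / alpha) ** (-(alpha + 1.0) / 2.0)
    return q / q.sum(dim=1, keepdim=True)


def target_distribution(q):
    # Sharpen soft assignments q (N x K): square them, normalize by
    # per-cluster frequency, then renormalize each row to a distribution.
    weight = q ** 2 / q.sum(0)
    return (weight.t() / weight.sum(1)).t()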
def train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            _, feat_g = model(g_x.to(device))
            prob = feat2prob(feat, model.center)
            prob_g = feat2prob(feat_g, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            mse_loss = F.mse_loss(prob, prob_g)
            loss = loss + w * mse_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)
        args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
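# --- Hedged sketch, not the repo's actual code ---
# ramps.sigmoid_rampup is a project-local module. The standard form of this
# ramp-up, following Laine & Aila (2017), is sketched below; the repo's
# version may differ slightly.
import numpy as np


def sigmoid_rampup(current, rampup_length):
    # Exponential ramp-up from 0 to 1 over rampup_length epochs:
    # exp(-5 * (1 - t)^2), with t = current / rampup_length clipped to [0, 1].
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))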
def PI_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.lr,
                    momentum=args.momentum, weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=args.gamma)
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        for batch_idx, ((x, x_bar), label, idx) in enumerate(tqdm(train_loader)):
            x, x_bar = x.to(device), x_bar.to(device)
            _, feat = model(x)
            _, feat_bar = model(x_bar)
            prob = feat2prob(feat, model.center)
            prob_bar = feat2prob(feat_bar, model.center)
            sharp_loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            # Pi-model consistency between two augmented views of the same image
            consistency_loss = F.mse_loss(prob, prob_bar)
            loss = sharp_loss + w * consistency_loss
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        exp_lr_scheduler.step()  # step once per epoch, after the optimizer updates
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args, epoch)
        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
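# --- Hedged sketch, not the repo's actual code ---
# PI_train expects each sample as a pair of augmented views ((x, x_bar)). A
# common way to produce such pairs is a twice-applied transform wrapper like
# the one below; the repo's dataset code may use a different mechanism.
class TransformTwice:
    # Wraps a transform so the dataset returns two independently augmented
    # views of each image, as consumed by PI_train above.
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        return self.transform(inp), self.transform(inp)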
def Baseline_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.lr,
                    momentum=args.momentum, weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=args.gamma)
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, (x, label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device)
            feat = model(x)
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        exp_lr_scheduler.step()  # step once per epoch, after the optimizer updates
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args)
        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
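# --- Hedged sketch, not the repo's actual code ---
# AverageMeter is a common bookkeeping utility; a minimal version consistent
# with the update(val, n) / .avg usage above:
class AverageMeter(object):
    # Tracks a running average of a scalar (here, the per-batch loss,
    # weighted by batch size).
    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count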
def TEP_train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain, args.n_clusters).float().to(device)        # accumulated ensemble
    z_bars = torch.zeros(ntrain, args.n_clusters).float().to(device)   # bias-corrected targets
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(device)  # current-epoch outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)
        z_epoch = probs.float().to(device)
        Z = alpha * Z + (1. - alpha) * z_epoch
        z_bars = Z * (1. / (1. - alpha ** (epoch + 1)))
        if epoch % args.update_interval == 0:
            args.p_targets = target_distribution(z_bars).float().to(device)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
def warmup_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.warmup_lr,
                    momentum=args.momentum, weight_decay=args.weight_decay)
    for epoch in range(args.warmup_epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, ((x, _), label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device)
            _, feat = model(x)
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('Warmup_train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args, epoch)
        args.p_targets = target_distribution(probs)
def TE_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.lr,
                    momentum=args.momentum, weight_decay=args.weight_decay)
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain, args.n_clusters).float().to(device)        # accumulated ensemble
    z_ema = torch.zeros(ntrain, args.n_clusters).float().to(device)    # bias-corrected ensemble targets
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(device)  # current-epoch outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        for batch_idx, ((x, _), label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device)
            feat = model(x)
            prob = feat2prob(feat, model.center)
            z_epoch[idx, :] = prob.detach()  # store predictions without tracking gradients
            prob_bar = z_ema[idx, :]  # ensembled prediction serves as a fixed consistency target
            sharp_loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            consistency_loss = F.mse_loss(prob, prob_bar)
            loss = sharp_loss + w * consistency_loss
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Z = alpha * Z + (1. - alpha) * z_epoch
        z_ema = Z * (1. / (1. - alpha ** (epoch + 1)))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        acc, _, _, probs = test(model, eva_loader, args)
        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
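# Note on the Z / z_ema updates above (and the epoch-level variants in the
# TEP_train functions): they implement temporal ensembling with startup bias
# correction as in Laine & Aila (2017), analogous to Adam's moment estimates.
# In LaTeX:
#   Z^{(t)} = \alpha Z^{(t-1)} + (1 - \alpha) z^{(t)}
#   \tilde{z}^{(t)} = Z^{(t)} / (1 - \alpha^{t+1})
# Because Z is initialized to zero, the raw moving average is biased toward
# zero in early epochs; dividing by (1 - \alpha^{t+1}) removes that bias.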
def Baseline_train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)
        if epoch % args.update_interval == 0:
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
args.model_dir = model_dir + '/' + 'vgg6_{}.pth'.format(alphabetStr)
args.save_txt_path = args.exp_root + '{}/{}'.format(runner_name, args.save_txt_name)
train_Dloader, eval_Dloader = omniglot_alphabet_func(
    alphabet=alphabetStr, background=False, root=args.dataset_root)(
        batch_size=args.batch_size, num_workers=args.num_workers)
args.n_clusters = alphabets_k_mapping[alphabetStr]
model = VGG(n_layer='4+2', out_dim=args.n_clusters, in_channels=1).to(device)
model.load_state_dict(torch.load(args.pretrain_dir), strict=False)
model.center = Parameter(torch.Tensor(args.n_clusters, args.n_clusters))
init_centers, init_probs = init_prob_kmeans(init_feat_extractor, eval_Dloader, args)
args.p_targets = target_distribution(init_probs)
model.center.data = torch.tensor(init_centers).float().to(device)
warmup_train(model, alphabetStr, train_Dloader, eval_Dloader, args)
train(model, alphabetStr, train_Dloader, eval_Dloader, args)
acc[alphabetStr], nmi[alphabetStr], ari[alphabetStr], _ = test(model, eval_Dloader, args)
print('ACC for all alphabets:', acc)
print('NMI for all alphabets:', nmi)
print('ARI for all alphabets:', ari)
avg_acc = sum(acc.values()) / float(len(acc))
avg_nmi = sum(nmi.values()) / float(len(nmi))
avg_ari = sum(ari.values()) / float(len(ari))
print('avg ACC {:.4f}, NMI {:.4f} ARI {:.4f}'.format(avg_acc, avg_nmi, avg_ari))
if args.save_txt:
    with open(args.save_txt_path, 'a') as f: