import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam, SGD, lr_scheduler
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.metrics import normalized_mutual_info_score as nmi_score
from tqdm import tqdm

# `ramps`, `AverageMeter`, `feat2prob`, `target_distribution`, and `cluster_acc`
# are project-local helpers imported elsewhere in the repo; illustrative
# sketches of what they presumably do are given further down.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        # ramp up the consistency weight over the first rampup_length epochs
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            _, feat_g = model(g_x.to(device))
            prob = feat2prob(feat, model.center)
            prob_g = feat2prob(feat_g, model.center)
            # sharpening loss against the current target distribution
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            # consistency between the original and transformed views
            mse_loss = F.mse_loss(prob, prob_g)
            loss = loss + w * mse_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)
        args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
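# `feat2prob` and `target_distribution` are not defined in this file. The
# sketch below is a minimal reconstruction assuming the usual DEC-style
# formulation (Xie et al., 2016): a Student's t-kernel soft assignment over
# the cluster centers, and a squared, frequency-normalized target. The
# project's actual helpers may differ in details such as the kernel's
# degrees of freedom `alpha`.

def feat2prob(feat, center, alpha=1.0):
    # Soft assignment q_ij ∝ (1 + ||z_i - mu_j||^2 / alpha)^(-(alpha + 1) / 2),
    # normalized over clusters j. feat: (N, D), center: (K, D) -> q: (N, K).
    q = 1.0 / (1.0 + torch.sum((feat.unsqueeze(1) - center) ** 2, dim=2) / alpha)
    q = q ** ((alpha + 1.0) / 2.0)
    return q / q.sum(dim=1, keepdim=True)

def target_distribution(q):
    # Sharpened target p_ij ∝ q_ij^2 / f_j, where f_j is the soft cluster
    # frequency; squaring pushes assignments toward high-confidence clusters.
    weight = q ** 2 / q.sum(dim=0)
    return (weight.t() / weight.sum(dim=1)).t()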
def PI_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)
    w = 0
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        # stepping the scheduler at the start of the epoch follows the
        # pre-1.1 PyTorch convention; PyTorch >= 1.1 expects it after optimizer.step()
        exp_lr_scheduler.step()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        for batch_idx, ((x, x_bar), label, idx) in enumerate(tqdm(train_loader)):
            x, x_bar = x.to(device), x_bar.to(device)
            _, feat = model(x)
            _, feat_bar = model(x_bar)
            prob = feat2prob(feat, model.center)
            prob_bar = feat2prob(feat_bar, model.center)
            sharp_loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            consistency_loss = F.mse_loss(prob, prob_bar)
            loss = sharp_loss + w * consistency_loss
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args, epoch)
        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
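# `ramps.sigmoid_rampup` and `AverageMeter` are likewise project utilities.
# Sketches follow, assuming the common Mean Teacher-style sigmoid ramp-up
# (Tarvainen & Valpola, 2017) and the standard running-average meter; the
# project's own versions may differ slightly.

def sigmoid_rampup(current, rampup_length):
    # Ramps from 0 to 1 as exp(-5 * (1 - t)^2) with t = current / rampup_length.
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))

class AverageMeter(object):
    # Tracks a running sum and count to report the average of a scalar.
    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count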
def init_prob_kmeans(model, eval_loader, args):
    torch.manual_seed(1)
    model = model.to(device)
    # cluster parameter initialization
    model.eval()
    targets = np.zeros(len(eval_loader.dataset))
    feats = np.zeros((len(eval_loader.dataset), 1024))
    for _, (x, _, label, idx) in enumerate(eval_loader):
        x = x.to(device)
        _, feat = model(x)
        feat = feat.view(x.size(0), -1)
        idx = idx.data.cpu().numpy()
        feats[idx, :] = feat.data.cpu().numpy()
        targets[idx] = label.data.cpu().numpy()
    # evaluate clustering performance
    pca = PCA(n_components=args.n_clusters)
    feats = pca.fit_transform(feats)
    kmeans = KMeans(n_clusters=args.n_clusters, n_init=20)
    y_pred = kmeans.fit_predict(feats)
    acc, nmi, ari = cluster_acc(targets, y_pred), nmi_score(targets, y_pred), ari_score(targets, y_pred)
    print('Init acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(acc, nmi, ari))
    probs = feat2prob(torch.from_numpy(feats), torch.from_numpy(kmeans.cluster_centers_))
    return kmeans.cluster_centers_, probs
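# For illustration only: one plausible way to wire `init_prob_kmeans` into a
# run. The training functions above read `model.center` and `args.p_targets`,
# so something must register them; `setup_clustering` is a hypothetical name
# for that step, not code taken from this project.

def setup_clustering(model, eval_loader, args):
    centers, probs = init_prob_kmeans(model, eval_loader, args)
    # register the k-means centers as a learnable parameter (assumption)
    model.center = torch.nn.Parameter(torch.tensor(centers, dtype=torch.float, device=device))
    # seed the sharpened targets from the initial soft assignments
    args.p_targets = target_distribution(probs)
    return model, args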
def Baseline_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        exp_lr_scheduler.step()
        for batch_idx, (x, label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device)
            feat = model(x)
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args)
        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
def TEP_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)
    w = 0
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain, args.n_clusters).float().to(device)        # intermediate values
    z_bars = torch.zeros(ntrain, args.n_clusters).float().to(device)   # temporal outputs
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(device)  # current outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        exp_lr_scheduler.step()
        for batch_idx, ((x, _), label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device)
            _, feat = model(x)
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args, epoch)
        z_epoch = probs.float().to(device)
        Z = alpha * Z + (1. - alpha) * z_epoch
        z_bars = Z * (1. / (1. - alpha ** (epoch + 1)))
        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(z_bars).float().to(device)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
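# The `z_bars = Z / (1 - alpha**(epoch + 1))` line above is the same
# zero-initialization bias correction used by Adam: Z starts at zero, so the
# EMA underestimates early epochs, and dividing by (1 - alpha^(t+1)) rescales
# it to an unbiased average. A quick sanity check (`_check_bias_correction`
# is an illustrative helper, not part of the project):

def _check_bias_correction(alpha=0.6, value=1.0, epochs=5):
    # Feed a constant prediction; the corrected ensemble recovers it exactly,
    # since Z after t+1 steps equals value * (1 - alpha^(t+1)).
    Z = 0.0
    for t in range(epochs):
        Z = alpha * Z + (1.0 - alpha) * value
        assert abs(Z / (1.0 - alpha ** (t + 1)) - value) < 1e-12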
def TEP_train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    w = 0
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain, args.n_clusters).float().to(device)        # intermediate values
    z_ema = torch.zeros(ntrain, args.n_clusters).float().to(device)    # temporal outputs
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(device)  # current outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)
        z_epoch = probs.float().to(device)
        Z = alpha * Z + (1. - alpha) * z_epoch
        z_bars = Z * (1. / (1. - alpha ** (epoch + 1)))
        if epoch % args.update_interval == 0:
            args.p_targets = target_distribution(z_bars).float().to(device)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
def warmup_train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.warmup_lr)
    for epoch in range(args.warmup_epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        print('Warmup Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        test(model, eval_loader, args)
def warmup_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.warmup_lr, momentum=args.momentum, weight_decay=args.weight_decay)
    for epoch in range(args.warmup_epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, ((x, _), label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device)
            _, feat = model(x)
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('Warmup Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eva_loader, args, epoch)
        args.p_targets = target_distribution(probs)
def TE_train(model, train_loader, eva_loader, args):
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    w = 0
    alpha = 0.6
    ntrain = len(train_loader.dataset)
    Z = torch.zeros(ntrain, args.n_clusters).float().to(device)        # intermediate values
    z_ema = torch.zeros(ntrain, args.n_clusters).float().to(device)    # temporal outputs
    z_epoch = torch.zeros(ntrain, args.n_clusters).float().to(device)  # current outputs
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        w = args.rampup_coefficient * ramps.sigmoid_rampup(epoch, args.rampup_length)
        for batch_idx, ((x, _), label, idx) in enumerate(tqdm(train_loader)):
            x = x.to(device)
            feat = model(x)
            prob = feat2prob(feat, model.center)
            # store the current predictions detached, so the ensemble buffers
            # stay out of the autograd graph across epochs
            z_epoch[idx, :] = prob.detach()
            prob_bar = z_ema[idx, :].detach()  # replaces the deprecated Variable(..., requires_grad=False)
            sharp_loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            consistency_loss = F.mse_loss(prob, prob_bar)
            loss = sharp_loss + w * consistency_loss
            loss_record.update(loss.item(), x.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # update the temporal ensemble and correct the zero-initialization bias
        Z = alpha * Z + (1. - alpha) * z_epoch
        z_ema = Z * (1. / (1. - alpha ** (epoch + 1)))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        acc, _, _, probs = test(model, eva_loader, args)
        if epoch % args.update_interval == 0:
            print('updating target ...')
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
def test(model, test_loader, args, epoch='test'):
    model.eval()
    preds = np.array([])
    targets = np.array([])
    feats = np.zeros((len(test_loader.dataset), args.n_clusters))
    probs = np.zeros((len(test_loader.dataset), args.n_clusters))
    for batch_idx, (x, label, idx) in enumerate(tqdm(test_loader)):
        x, label = x.to(device), label.to(device)
        _, feat = model(x)
        prob = feat2prob(feat, model.center)
        _, pred = prob.max(1)
        targets = np.append(targets, label.cpu().numpy())
        preds = np.append(preds, pred.cpu().numpy())
        idx = idx.data.cpu().numpy()
        feats[idx, :] = feat.cpu().detach().numpy()
        probs[idx, :] = prob.cpu().detach().numpy()
    acc, nmi, ari = cluster_acc(targets.astype(int), preds.astype(int)), nmi_score(targets, preds), ari_score(targets, preds)
    print('Test acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(acc, nmi, ari))
    probs = torch.from_numpy(probs)
    return acc, nmi, ari, probs
def Baseline_train(model, alphabetStr, train_loader, eval_loader, args):
    optimizer = Adam(model.parameters(), lr=args.lr)
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        model.train()
        for batch_idx, (x, g_x, _, idx) in enumerate(train_loader):
            _, feat = model(x.to(device))
            prob = feat2prob(feat, model.center)
            loss = F.kl_div(prob.log(), args.p_targets[idx].float().to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_record.update(loss.item(), x.size(0))
        print('Train Epoch: {} Avg Loss: {:.4f}'.format(epoch, loss_record.avg))
        _, _, _, probs = test(model, eval_loader, args)
        if epoch % args.update_interval == 0:
            args.p_targets = target_distribution(probs)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
def test(model, eval_loader, args):
    model.eval()
    targets = np.zeros(len(eval_loader.dataset))
    y_pred = np.zeros(len(eval_loader.dataset))
    probs = np.zeros((len(eval_loader.dataset), args.n_clusters))
    for _, (x, _, label, idx) in enumerate(eval_loader):
        x = x.to(device)
        _, feat = model(x)
        prob = feat2prob(feat, model.center)
        idx = idx.data.cpu().numpy()
        y_pred[idx] = prob.data.cpu().detach().numpy().argmax(1)
        targets[idx] = label.data.cpu().numpy()
        probs[idx, :] = prob.cpu().detach().numpy()
    # evaluate clustering performance
    y_pred = y_pred.astype(np.int64)
    acc, nmi, ari = cluster_acc(targets, y_pred), nmi_score(targets, y_pred), ari_score(targets, y_pred)
    print('Test acc {:.4f}, nmi {:.4f}, ari {:.4f}'.format(acc, nmi, ari))
    probs = torch.from_numpy(probs)
    return acc, nmi, ari, probs
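# `cluster_acc` is also a project-local helper. A common implementation,
# sketched here under that assumption, matches predicted clusters to ground
# truth labels with the Hungarian algorithm before computing accuracy:

from scipy.optimize import linear_sum_assignment

def cluster_acc(y_true, y_pred):
    # Build the confusion matrix, then find the cluster-to-label assignment
    # that maximizes the number of agreeing samples.
    y_true = np.asarray(y_true).astype(np.int64)
    y_pred = np.asarray(y_pred).astype(np.int64)
    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    for yp, yt in zip(y_pred, y_true):
        w[yp, yt] += 1
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() / y_pred.size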