Example #1
def finetune(model):
    print("Start finetuning")
    dataset = Dataset(split='gold', use_gold=True)
    loader = torch.utils.data.DataLoader(dataset,
                                         128,
                                         shuffle=True,
                                         num_workers=0,
                                         collate_fn=collate_samples)
    criterion = nn.BCEWithLogitsLoss(reduction='none')
    optim = torch.optim.Adam(model.parameters())
    model.train()
    vat_loss = VATLoss()
    for _, batch in enumerate(loader):
        move_batch(batch)
        optim.zero_grad()

        xs = model(**batch, feature_only=True)
        lds = vat_loss(model, xs)
        logits = model(xs=xs)
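        # up-weight positive targets: weight 9 where ys == 1, weight 1 elsewhere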
        pos_weight = 1. + (batch['ys'] * 8.)
        loss = (pos_weight * criterion(logits, batch['ys'])).mean() + .1 * lds
        loss.backward()
        optim.step()

        del batch
    validate(model)
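Note: every snippet on this page calls a VATLoss module with the signature VATLoss(xi=..., eps=..., ip=...) whose forward pass takes the model and a batch of (typically unlabeled) inputs and returns the local distributional smoothness (LDS) term. The sketch below is a minimal stand-in modeled on the widely used standalone PyTorch implementation of virtual adversarial training; the class actually imported by each repository may differ in detail (for example, batch-norm handling is omitted here).

import torch
import torch.nn as nn
import torch.nn.functional as F


def _l2_normalize(d):
    # normalize each sample's perturbation to unit L2 norm
    d_flat = d.view(d.size(0), -1)
    norm = d_flat.norm(dim=1).view(-1, *([1] * (d.dim() - 1)))
    return d / (norm + 1e-8)


class VATLoss(nn.Module):
    def __init__(self, xi=10.0, eps=1.0, ip=1):
        # xi: finite-difference step size, eps: norm of the final adversarial
        # perturbation, ip: number of power-iteration steps
        super().__init__()
        self.xi = xi
        self.eps = eps
        self.ip = ip

    def forward(self, model, x):
        with torch.no_grad():
            pred = F.softmax(model(x), dim=1)

        # start from a random unit perturbation
        d = _l2_normalize(torch.rand_like(x) - 0.5)

        # power iteration: estimate the direction that changes the prediction the most
        for _ in range(self.ip):
            d.requires_grad_()
            pred_hat = model(x + self.xi * d)
            adv_distance = F.kl_div(F.log_softmax(pred_hat, dim=1), pred,
                                    reduction='batchmean')
            adv_distance.backward()
            d = _l2_normalize(d.grad)
            model.zero_grad()

        # LDS: divergence between predictions on clean and adversarially perturbed inputs
        pred_hat = model(x + self.eps * d)
        return F.kl_div(F.log_softmax(pred_hat, dim=1), pred, reduction='batchmean')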
Example #2
def train(model, device, loader, optimizer, config):
    model.train()

    correct = 0
    total_loss = 0

    for batch_idx, sample in enumerate(loader):
        data = sample['image']
        target = sample['target']

        if config.vat:
            vat_loss = VATLoss(xi=config.xi, eps=config.eps, ip=config.ip)
            ul_data = sample['ul_image']
            ul_data = ul_data.to(device)
            lds = vat_loss(model, ul_data)

        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss_fun = torch.nn.CrossEntropyLoss()

        if config.null_space_tuning:
            ul_data1 = sample['ul_img1']
            ul_data2 = sample['ul_img2']

            ul_data1, ul_data2 = ul_data1.to(device), ul_data2.to(device)

            ul_out1 = model(ul_data1)
            ul_out2 = model(ul_data2)

            null_loss_fun = torch.nn.MSELoss()
            loss = loss_fun(output, target) + config.alpha * null_loss_fun(
                ul_out1, ul_out2)
        elif config.vat:
            loss = loss_fun(output, target) + config.alpha * lds
        else:
            loss = loss_fun(output, target)

        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(target.view(-1, 1)).sum().item()

        total_loss += loss.item()
        loss.backward()
        optimizer.step()

    avg_loss = total_loss / (batch_idx + 1)
    accuracy = 100 * (correct / loader.dataset.len)

    print('\tTraining set: Average loss: {:.4f}, Accuracy: {:.0f}%'.format(
        avg_loss, accuracy))

    return avg_loss, accuracy
Example #3
def train(args, model, device, data_iterators, optimizer):
    model.train()
    for i in tqdm(range(args.iters)):

        # reset
        if i % args.log_interval == 0:
            ce_losses = utils.AverageMeter()
            vat_losses = utils.AverageMeter()
            prec1 = utils.AverageMeter()

        x_l, y_l = next(data_iterators['labeled'])
        x_ul, _ = next(data_iterators['unlabeled'])

        x_l, y_l = x_l.to(device), y_l.to(device)
        x_ul = x_ul.to(device)

        optimizer.zero_grad()

        vat_loss = VATLoss(xi=args.xi, eps=args.eps, ip=args.ip)
        cross_entropy = nn.CrossEntropyLoss()

        lds = vat_loss(model, x_ul)
        output = model(x_l)
        classification_loss = cross_entropy(output, y_l)
        loss = classification_loss + args.alpha * lds
        loss.backward()
        optimizer.step()

        acc = utils.accuracy(output, y_l)
        ce_losses.update(classification_loss.item(), x_l.shape[0])
        vat_losses.update(lds.item(), x_ul.shape[0])
        prec1.update(acc.item(), x_l.shape[0])

        if i % args.log_interval == 0:
            print(
                f'\nIteration: {i}\t'
                f'CrossEntropyLoss {ce_losses.val:.4f} ({ce_losses.avg:.4f})\t'
                f'VATLoss {vat_losses.val:.4f} ({vat_losses.avg:.4f})\t'
                f'Prec@1 {prec1.val:.3f} ({prec1.avg:.3f})')
Example #4
def scAdapt(args, data_set):
    ## prepare data
    batch_size = args.batch_size
    kwargs = {'num_workers': 0, 'pin_memory': True}

    source_name = args.source_name #"TM_baron_mouse_for_baron"
    target_name = args.target_name #"baron_human"
    domain_to_indices = np.where(data_set['accessions'] == source_name)[0]
    train_set = {'features': data_set['features'][domain_to_indices], 'labels': data_set['labels'][domain_to_indices],
                 'accessions': data_set['accessions'][domain_to_indices]}
    domain_to_indices = np.where(data_set['accessions'] == target_name)[0]
    test_set = {'features': data_set['features'][domain_to_indices], 'labels': data_set['labels'][domain_to_indices],
                'accessions': data_set['accessions'][domain_to_indices]}
    print('source labels:', np.unique(train_set['labels']), ' target labels:', np.unique(test_set['labels']))
    test_set_eval = {'features': data_set['features'][domain_to_indices], 'labels': data_set['labels'][domain_to_indices],
                     'accessions': data_set['accessions'][domain_to_indices]}
    print(train_set['features'].shape, test_set['features'].shape)

    data = torch.utils.data.TensorDataset(
        torch.FloatTensor(train_set['features']), torch.LongTensor(matrix_one_hot(train_set['labels'], int(max(train_set['labels'])+1)).long()))
    source_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)

    data = torch.utils.data.TensorDataset(
        torch.FloatTensor(test_set['features']), torch.LongTensor(matrix_one_hot(test_set['labels'], int(max(train_set['labels'])+1)).long()))
    target_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
    target_test_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False, drop_last=False,
                                                     **kwargs)
    class_num = max(train_set['labels'])+1
    class_num_test = max(test_set['labels']) + 1

    ### re-weighting the classifier
    cls_num_list = [np.sum(train_set['labels'] == i) for i in range(class_num)]
    #from https://github.com/YyzHarry/imbalanced-semi-self/blob/master/train.py
    # # Normalized weights based on inverse number of effective data per class.
    #2019 Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss
    #2020 Rethinking the Value of Labels for Improving Class-Imbalanced Learning
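    # class-balanced weights: w_c is proportional to (1 - beta) / (1 - beta**n_c),
    # where n_c is the number of training samples in class c ("effective number of samples")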
    beta = 0.9999
    effective_num = 1.0 - np.power(beta, cls_num_list)
    per_cls_weights = (1.0 - beta) / np.array(effective_num)
    per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
    per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()


    ## set base network
    embedding_size = args.embedding_size
    base_network = FeatureExtractor(num_inputs=train_set['features'].shape[1], embed_size = embedding_size).cuda()
    label_predictor = LabelPredictor(base_network.output_num(), class_num).cuda()
    total_model = nn.Sequential(base_network, label_predictor)

    center_loss = CenterLoss(num_classes=class_num, feat_dim=embedding_size, use_gpu=True)
    optimizer_centloss = torch.optim.SGD([{'params': center_loss.parameters()}], lr=0.5)

    print("output size of FeatureExtractor and LabelPredictor: ", base_network.output_num(), class_num)
    ad_net = scAdversarialNetwork(base_network.output_num(), 1024).cuda()

    ## set optimizer
    config_optimizer = {"lr_type": "inv", "lr_param": {"lr": 0.001, "gamma": 0.001, "power": 0.75}}
    parameter_list = base_network.get_parameters() + ad_net.get_parameters() + label_predictor.get_parameters()
    optimizer = optim.SGD(parameter_list, lr=1e-3, weight_decay=5e-4, momentum=0.9, nesterov=True)
    schedule_param = config_optimizer["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[config_optimizer["lr_type"]]

    ## train
    len_train_source = len(source_loader)
    len_train_target = len(target_loader)
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    epoch_global = 0.0

    hit = False
    s_global_centroid = torch.zeros(class_num, embedding_size).cuda()
    t_global_centroid = torch.zeros(class_num, embedding_size).cuda()
    for epoch in range(args.num_iterations):
        if epoch % (2500) == 0 and epoch != 0:
            feature_target = base_network(torch.FloatTensor(test_set['features']).cuda())
            output_target = label_predictor.forward(feature_target)
            softmax_out = nn.Softmax(dim=1)(output_target)
            predict_prob_arr, predict_label_arr = torch.max(softmax_out, 1)
            if epoch == args.epoch_th:
                data = torch.utils.data.TensorDataset(torch.FloatTensor(test_set['features']), predict_label_arr.cpu())
                target_loader_align = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True,
                                                            **kwargs)

            result_path = args.result_path #"../results/"
            if not os.path.exists(result_path):
                os.makedirs(result_path)
            model_file = result_path + 'final_model_' + str(epoch) + source_name + target_name + '.ckpt'
            torch.save({'base_network': base_network.state_dict(), 'label_predictor': label_predictor.state_dict()}, model_file)

            with torch.no_grad():
                code_arr_s = base_network(Variable(torch.FloatTensor(train_set['features']).cuda()))
                code_arr_t = base_network(Variable(torch.FloatTensor(test_set_eval['features']).cuda()))
                code_arr = np.concatenate((code_arr_s.cpu().data.numpy(), code_arr_t.cpu().data.numpy()), 0)

            digit_label_dict = pd.read_csv(args.dataset_path + 'digit_label_dict.csv')
            digit_label_dict = pd.DataFrame(zip(digit_label_dict.iloc[:,0], digit_label_dict.index), columns=['digit','label'])
            digit_label_dict = digit_label_dict.to_dict()['label']
            # transform digit label to cell type name
            y_pred_label = [digit_label_dict[x] if x in digit_label_dict else x for x in predict_label_arr.cpu().data.numpy()]

            pred_labels_file = result_path + 'pred_labels_' + source_name + "_" + target_name + "_" + str(epoch) + ".csv"
            pd.DataFrame([predict_prob_arr.cpu().data.numpy(), y_pred_label],  index=["pred_probability", "pred_label"]).to_csv(pred_labels_file, sep=',')
            embedding_file = result_path + 'embeddings_' + source_name + "_" + target_name + "_" + str(epoch)+ ".csv"
            pd.DataFrame(code_arr).to_csv(embedding_file, sep=',')

            #### only for evaluation
            # acc_by_label = np.zeros( class_num_test )
            # all_label = test_set['labels']
            # for i in range(class_num_test):
            #     acc_by_label[i] = np.sum(predict_label_arr.cpu().data.numpy()[all_label == i] == i) / np.sum(all_label == i)
            # np.set_printoptions(suppress=True)
            # print('iter:', epoch, "average acc over all test cell types: ", round(np.nanmean(acc_by_label), 3))
            # print("acc of each test cell type: ", acc_by_label)

            # div_score, div_score_all, ent_score, sil_score = evaluate_multibatch(code_arr, train_set, test_set_eval, epoch)
            #results_file = result_path + source_name + "_" + target_name + "_" + str(epoch)+ "_acc_div_sil.csv"
            #evel_res = [np.nanmean(acc_by_label), div_score, div_score_all, ent_score, sil_score]
            #pd.DataFrame(evel_res, index = ["acc","div_score","div_score_all","ent_score","sil_score"], columns=["values"]).to_csv(results_file, sep=',')
            # pred_labels_file = result_path + source_name + "_" + target_name + "_" + str(epoch) + "_pred_labels.csv"
            # pd.DataFrame([predict_label_arr.cpu().data.numpy(), all_label],  index=["pred_label", "true_label"]).to_csv(pred_labels_file, sep=',')



        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        label_predictor.train(True)

        optimizer = lr_scheduler(optimizer, epoch, **schedule_param)
        optimizer.zero_grad()
        optimizer_centloss.zero_grad()

        if epoch % len_train_source == 0:
            iter_source = iter(source_loader)
            epoch_global = epoch_global + 1
        if epoch % len_train_target == 0:
            if epoch < args.epoch_th:
                iter_target = iter(target_loader)
            else:
                hit = True
                iter_target = iter(target_loader_align)
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source, labels_target = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda(), labels_target.cuda()

        feature_source = base_network(inputs_source)
        feature_target = base_network(inputs_target)
        features = torch.cat((feature_source, feature_target), dim=0)

        output_source = label_predictor.forward(feature_source)
        output_target = label_predictor.forward(feature_target)

        ######## VAT and BNM loss
        # LDS should be calculated before the forward for cross entropy
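        # (the VAT forward passes build the adversarial perturbation from the model's
        #  current predictions, so the regularizer is computed before the supervised forward)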
        vat_loss = VATLoss(xi=args.xi, eps=args.eps, ip=args.ip)
        lds_loss = vat_loss(total_model, inputs_target)

        softmax_tgt = nn.Softmax(dim=1)(output_target[:, 0:class_num])
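        # BNM (Batch Nuclear-norm Maximization): maximizing the mean singular value of the
        # batch softmax matrix encourages discriminative and diverse target predictions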
        _, s_tgt, _ = torch.svd(softmax_tgt)
        BNM_loss = -torch.mean(s_tgt)

        ########domain alignment loss
        if args.method == 'DANN':
            domain_prob_discriminator_1_source = ad_net.forward(feature_source)
            domain_prob_discriminator_1_target = ad_net.forward(feature_target)

            adv_loss = loss_utility.BCELossForMultiClassification(label=torch.ones_like(domain_prob_discriminator_1_source), \
                                                     predict_prob=domain_prob_discriminator_1_source)  # domain matching
            adv_loss += loss_utility.BCELossForMultiClassification(label=torch.ones_like(domain_prob_discriminator_1_target), \
                                                      predict_prob=1 - domain_prob_discriminator_1_target)

            transfer_loss = adv_loss
        elif args.method == 'mmd':
            base = 1.0  # sigma for MMD
            sigma_list = [1, 2, 4, 8, 16]
            sigma_list = [sigma / base for sigma in sigma_list]
            transfer_loss = loss_utility.mix_rbf_mmd2(feature_source, feature_target, sigma_list)


        ######CrossEntropyLoss
        classifier_loss = nn.CrossEntropyLoss(weight=per_cls_weights)(output_source, torch.max(labels_source, dim=1)[1])
        # classifier_loss = loss_utility.CrossEntropyLoss(labels_source.float(), nn.Softmax(dim=1)(output_source))

        ######semantic_loss and center loss
        cell_th = args.cell_th
        epoch_th = args.epoch_th
        if epoch < args.epoch_th or not hit:
            semantic_loss = torch.FloatTensor([0.0]).cuda()
            center_loss_src = torch.FloatTensor([0.0]).cuda()
            sum_dist_loss = torch.FloatTensor([0.0]).cuda()
            # center_loss.centers = feature_source[torch.max(labels_source, dim=1)[1] == 0].mean(dim=0, keepdim=True)
        elif hit:
            center_loss_src = center_loss(feature_source,
                                          labels=torch.max(labels_source, dim=1)[1])
            s_global_centroid = center_loss.centers
            semantic_loss, s_global_centroid, t_global_centroid = loss_utility.semant_use_s_center(class_num,
                                                                                           s_global_centroid,
                                                                                           t_global_centroid,
                                                                                           feature_source,
                                                                                           feature_target,
                                                                                           torch.max(
                                                                                               labels_source,
                                                                                               dim=1)[1],
                                                                                           labels_target, 0.7,
                                                                                           cell_th) #softmax_tgt

        if epoch > epoch_th:
            lds_loss = torch.FloatTensor([0.0]).cuda()
        if epoch <= args.num_iterations:
            progress = epoch / args.epoch_th #args.num_iterations
        else:
            progress = 1
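        # DANN-style sigmoid ramp-up: lambd grows from 0 towards 1 with training progress,
        # gradually switching on the transfer / BNM / VAT terms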
        lambd = 2 / (1 + math.exp(-10 * progress)) - 1

        total_loss = classifier_loss + lambd * args.DA_coeff * transfer_loss \
            + lambd * args.BNM_coeff * BNM_loss + lambd * args.alpha * lds_loss \
            + args.semantic_coeff * semantic_loss + args.centerloss_coeff * center_loss_src

        total_loss.backward()
        optimizer.step()

        # multiply the gradient by (1. / centerloss_coeff) to remove the effect of centerloss_coeff when updating the centers
        if args.centerloss_coeff > 0 and center_loss_src > 0:
            for param in center_loss.parameters():
                param.grad.data *= (1. / args.centerloss_coeff)
            optimizer_centloss.step() #optimize the center in center loss
Example #5
     #calculate Vst
     Vst = Classification_loss + args.lambda_val * Adversarial_DA_loss
 
     t_class_output, _, domain_output_mt, _ = my_net(t_imgs, alpha=alpha)
     
     #calculate Vmt
     if args.dataset_name == "OfficeHome":
         Vmt = loss_class(domain_output_mt, t_labels)
     else:
         t_labels = torch.unsqueeze(t_labels, 1)
         t_labels = t_labels.float()
         Vmt = loss_domain(domain_output_mt, t_labels)
     #calculate Lent
     Lent = loss_entropy(t_class_output)
     #calculate Lvir
     vat_loss = VATLoss(xi=10.0, eps=1.0, ip=1)
     s_vat_loss = vat_loss(my_net, s_imgs)
     t_vat_loss = vat_loss(my_net, t_imgs)
     Lvir = s_vat_loss + args.rho_val * t_vat_loss

     loss = Vst + args.gamma_val * Vmt + args.beta_val * Lent + Lvir
     loss.backward()
     optim_net.step()
     
     
     if iter_count%10 == 0:
         tqdm.write("iter is : {}  Classification loss is :{:.3f} Lent is :{:.3f}  Adversarial DA loss is :{:.3f} Vmt is:{:.3f} gamma_val is:{:.3f}".format(iter_count,Classification_loss.item(),\
             Lent.item(), Adversarial_DA_loss.item(),Vmt.item(), args.gamma_val))
         print_log(iter_count+1, \
         Classification_loss.item(), \
         Lent.item(), \
Example #6
    writer.add_figure(nametag, fig, epoch)


### Model
input_size = 512
num_feature = 1024
class_num = 10

generator_g = net.Encoder().cuda()
classifier_c = net.Classifier(512, class_num).cuda()
classifier_j = net.Classifier(512, class_num * 2).cuda()

### Loss & Optimizers
# Loss functions
criterion_CE = torch.nn.CrossEntropyLoss().cuda()
criterion_VAT = VATLoss().cuda()

# Optimizers

generator_g_optim = torch.optim.Adam(generator_g.parameters(), lr=opt.lr)
classifier_c_optim = torch.optim.Adam(classifier_c.parameters(), lr=opt.lr)
classifier_j_optim = torch.optim.Adam(classifier_j.parameters(), lr=opt.lr)

# ----------
#  Training
# ----------
print('')
niter = 0
epoch = 0
best_test = 0
Example #7
def train_vat(fold=2, disable_progress=False):
    directory = './data/'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    train_df = pd.read_csv('folds.csv')
    train_path = os.path.join(directory, 'train')

    train_ids = train_df[train_df.fold != fold]['id'].values
    val_ids = train_df[train_df.fold == fold]['id'].values

    dataset = TGSSaltDataset(train_path, train_ids, augment=True)
    dataset_val = TGSSaltDataset(train_path, val_ids)

    test_path = os.path.join(directory, 'test')
    test_file_list = glob.glob(os.path.join(test_path, 'images', '*.png'))
    test_file_list = [f.split('/')[-1].split('.')[0] for f in test_file_list]
    dataset_test = TGSSaltDataset(test_path, test_file_list, is_test=True)

    model = UnetModel()
    model.train()
    model.to(device)

    #load_auto(f'auto-new-{num_filters}-s2-best.pth', model, device)
    #set_encoder_train(model, False)

    epoch = 100
    learning_rate = 5e-3
    alpha = 1.0
    #loss_fn = torch.nn.BCEWithLogitsLoss()
    vat_loss = VATLoss(eps=1.0, ip=1)
    empty_loss_fn = nn.BCEWithLogitsLoss()
    ent_loss_fn = EntropyLoss()
    #mask = get_mask().to(device)
    loss_fn = LovaszLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=learning_rate,
                                momentum=0.9,
                                nesterov=True,
                                weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=10,
                                                           factor=0.5)

    try:
        load_checkpoint(f'unet-fold{fold}-s2-best.pth', model, optimizer)
    except Exception:
        print('oops no file')

    train_iter = data.DataLoader(dataset, batch_size=16, shuffle=True)
    test_iter = data.DataLoader(dataset_test, batch_size=16, shuffle=True)
    best_iou = 0
    for e in range(epoch):
        train_loss = []
        smooth_loss = []
        unlabeled_loss = []
        #for sample in tqdm(data.DataLoader(dataset, batch_size = 30, shuffle = True)):
        for sample, test_sample in zip(tqdm(train_iter), test_iter):
            image, mask = sample['image'], sample['mask']
            image = image.type(torch.float).to(device)
            test_image = test_sample['image']
            test_image = test_image.type(torch.float).to(device)

            lds = 0  # vat_loss(model, image)
            lds_test = 0  # vat_loss(model, test_image)
            y_pred, y_pred_empty, _ = model(image)
            test_pred, test_pred_empty, _ = model(test_image)
            ul_loss = ent_loss_fn(test_pred_empty)
            direct_loss = loss_fn(y_pred, mask.to(device))
            class_loss = empty_loss_fn(y_pred_empty,
                                       empty_mask(mask).to(device))
            loss = direct_loss + 0.05 * (class_loss + alpha *
                                         (lds + lds_test) + ul_loss)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            train_loss.append(direct_loss.item())
            smooth_loss.append(0)  #lds.item() + lds_test.item())
            unlabeled_loss.append(ul_loss.item())

        val_loss = []
        val_iou = []
        for sample in data.DataLoader(dataset_val,
                                      batch_size=16,
                                      shuffle=False):
            image, mask = sample['image'], sample['mask']
            image = image.to(device)
            y_pred, _, _ = model(image)

            loss = loss_fn(y_pred, mask.to(device))
            val_loss.append(loss.item())
            iou = iou_pytorch((y_pred > 0).int(), mask.int().to(device)).cpu()
            val_iou.append(iou)

        avg_iou = np.mean(val_iou)
        scheduler.step(np.mean(val_loss))
        print(
            "Epoch: %d, Train: %.3f, Smooth: %.3f, UL: %.3f, Val: %.3f, IoU: %.3f"
            % (e, np.mean(train_loss), np.mean(smooth_loss),
               np.mean(unlabeled_loss), np.mean(val_loss), avg_iou))
        if avg_iou > best_iou:
            print('saving new best')
            save_checkpoint(f'unet-fold{fold}-vat-best.pth', model, optimizer)
            best_iou = avg_iou

    print('Best IoU: %.3f' % best_iou)
Example #8
def train():
    LOG_FREQ = 10
    VAL_FREQ = 2000
    QUERY_FREQ = 400
    QUERY_SIZE = 20

    with open('data/priors.pkl', 'rb') as f:
        priors = pickle.load(f)
    mean_prior = np.mean(list(priors.values()))

    dataset = Dataset(use_gold=True)
    loader = torch.utils.data.DataLoader(dataset,
                                         128,
                                         shuffle=True,
                                         num_workers=0,
                                         collate_fn=collate_samples)
    model = Model(num_classes=len(dataset.type_to_id),
                  num_props=len(dataset.prop_to_id),
                  num_chars=len(dataset.char_to_id)).to(device)

    criterion = nn.BCELoss(reduction='none')
    optim = torch.optim.Adam(model.parameters())
    vat_loss = VATLoss()

    total_it = 0
    for ep in range(3):
        model.train()
        running_loss = 0
        running_acc = 0

        for i, batch in enumerate(loader):
            if total_it % QUERY_FREQ == 0:
                print(
                    f"Querying {len(gold_label_cache)} - {len(gold_label_cache) + QUERY_SIZE}"
                )

                run_query(model, QUERY_SIZE)
                for _ in range(2):
                    finetune(model)  # fine-tune on annotated pairs

            move_batch(batch)

            batch_priors = [
                0.5 + priors.get(n, mean_prior) / 2 for n in batch['names']
            ]
            batch_priors = torch.tensor(batch_priors, dtype=torch.float).view(
                -1, 1).repeat(1, model.num_classes).to(device)  # keep on the same device as is_gold_label

            is_gold_label = torch.tensor(
                [x in gold_label_cache for x in batch['names']],
                dtype=torch.float).to(device)
            is_gold_label = is_gold_label.view(-1,
                                               1).repeat(1, model.num_classes)

            batch_priors = torch.max(0.8 * batch_priors, 2. * is_gold_label)

            total_it += 1
            optim.zero_grad()

            xs = model(**batch, feature_only=True)

            lds = vat_loss(model, xs)
            logits = model(xs=xs)
            # MAP noise model
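            # the observed label is assumed correct with probability retain_probs:
            # P(y_obs = 1) = retain_probs * p + (1 - retain_probs) * (1 - p)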
            probs = logits.sigmoid()
            retain_probs = model.simple_transition.sigmoid()
            retain_probs = torch.max(retain_probs, is_gold_label)
            adjusted_probs = retain_probs * probs + \
                (1-retain_probs) * (1 - probs)
            # noise model + prior
            loss = (batch_priors *
                    criterion(adjusted_probs, batch['ys'])).mean() + .1 * lds
            # noise model
            # loss = (criterion(adjusted_probs, batch['ys'])).mean()
            # vanilla
            # loss = (criterion(probs, batch['ys'])).mean() + .1 * lds
            loss.backward()
            optim.step()

            running_loss += loss.item()

            del batch

            if i % LOG_FREQ == LOG_FREQ - 1:
                print(
                    f"Train {i:05d}/{ep:05d}  Loss {running_loss / LOG_FREQ:.4f}  Acc {running_acc / LOG_FREQ: .4f}"
                )
                running_acc = 0
                running_loss = 0.

            if i % VAL_FREQ == VAL_FREQ - 1:
                validate(model)
        validate(model)
        save_path = f'checkpoints/{model_name}_{ep}.pyt'
        print(f'Save to {save_path}')
        model.save(save_path)