Example #1
def score_fn(x, y=None):
    # Select the scoring rule according to args.score_fn.
    if args.score_fn == "px":
        # raw model output f(x) (unnormalized log-density / energy score)
        return f(x).detach().cpu()
    elif args.score_fn == "py":
        # maximum softmax probability of the classifier head
        return nn.functional.softmax(f.classify(x), dim=1).max(1)[0].detach().cpu()
    elif args.score_fn == "pxgrad":
        # negative log of the input-gradient norm
        return -torch.log(grad_norm(x).detach().cpu())
    elif args.score_fn == "pxsim":
        # contrastive similarity score from the joint head
        # (requires the contrastive objective to be enabled)
        assert args.pxycontrast > 0
        dist = smooth_one_hot(y, args.n_classes, args.smoothing)
        output, target, ce_output, neg_num = f.joint(img=x, dist=dist, evaluation=True)
        simloss = nn.CrossEntropyLoss(reduction="none")(output, target)
        simloss = simloss - np.log(neg_num)
        simloss = -1.0 * simloss
        return simloss.detach().cpu()
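
All four examples call smooth_one_hot(y, n_classes, smoothing) without showing its body. Below is a minimal sketch of a label-smoothing helper with that call signature; the exact formulation (the true class keeps 1 - smoothing, the rest is spread uniformly) is an assumption, not necessarily the repository's implementation.

import torch

def smooth_one_hot(labels, n_classes, smoothing=0.0):
    # Hypothetical sketch: convert integer labels to smoothed one-hot targets.
    # The true class gets 1 - smoothing; the remaining probability mass is
    # spread uniformly over the other classes.
    assert 0.0 <= smoothing < 1.0
    off_value = smoothing / (n_classes - 1)
    dist = torch.full((labels.size(0), n_classes), off_value,
                      device=labels.device)
    dist.scatter_(1, labels.long().unsqueeze(1), 1.0 - smoothing)
    return dist

# Example: 4 samples, 10 classes, 10% smoothing
y = torch.tensor([0, 3, 3, 7])
print(smooth_one_hot(y, n_classes=10, smoothing=0.1))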
Example #2
def sample_q(f,
             replay_buffer,
             y=None,
             n_steps=args.n_steps,
             contrast=False):
    """This function takes in replay_buffer so we have the option to sample
    from scratch (i.e. replay_buffer == []). See test_wrn_ebm.py for an example.
    """
    f.eval()
    # get batch size
    bs = args.sgld_batch_size if y is None else y.size(0)
    # generate initial samples and buffer inds of those samples (if buffer is used)
    init_sample, buffer_inds = sample_p_0(replay_buffer, bs=bs, y=y)
    x_k = init_sample.clone()
    x_k.requires_grad = True
    # SGLD: take n_steps of noisy gradient ascent on `energy`
    for k in range(n_steps):
        if not contrast:
            energy = f(x_k, y=y).sum()
        else:
            if y is not None:
                dist = smooth_one_hot(y, args.n_classes, args.smoothing)
            else:
                dist = torch.ones((bs, args.n_classes)).to(device)
            output, target, ce_output, neg_num = f.joint(img=x_k,
                                                         dist=dist,
                                                         evaluation=True)
            energy = -1.0 * nn.CrossEntropyLoss(reduction="mean")(output,
                                                                  target)
        f_prime = torch.autograd.grad(energy, [x_k], retain_graph=True)[0]
        x_k.data += args.sgld_lr * f_prime + args.sgld_std * torch.randn_like(x_k)
    f.train()
    final_samples = x_k.detach()
    # update replay buffer
    if len(replay_buffer) > 0:
        replay_buffer[buffer_inds] = final_samples.cpu()
    return final_samples
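
A possible way to call this sampler, mirroring how Example #4 uses it. The buffer initialization, args.buffer_size, and the CIFAR-style 3x32x32 image shape are assumptions about the surrounding script, which must already define f, args, and device.

import torch

# Hypothetical replay buffer of uniform noise in [-1, 1] (shape is assumed)
replay_buffer = torch.FloatTensor(args.buffer_size, 3, 32, 32).uniform_(-1, 1)

# Unconditional samples from the energy model
x_q = sample_q(f, replay_buffer)

# Class-conditional samples for random labels, using the contrastive energy
y_q = torch.randint(0, args.n_classes, (args.sgld_batch_size,)).to(device)
x_q_y = sample_q(f, replay_buffer, y=y_q, contrast=True)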
Example #3
def train(config,
          fold,
          model,
          dict_loader,
          optimizer,
          scheduler,
          list_dir_save_model,
          dir_pyplot,
          Validation=True,
          Test_flag=True):

    train_loader = dict_loader['train']
    val_loader = dict_loader['val']
    test_loader = dict_loader['test']
    """ loss """
    # criterion_cls = nn.CrossEntropyLoss()
    # criterion_cls = ut.FocalLoss(gamma=st.focal_gamma, alpha=st.focal_alpha, size_average=True)
    # kdloss = ut.KDLoss(4.0)
    criterion_KL = nn.KLDivLoss(reduction="sum")
    criterion_cls = nn.BCELoss()
    # criterion_L1 = nn.L1Loss(reduction='sum').cuda()
    # criterion_L2 = nn.MSELoss(reduction='mean').cuda()
    # criterion_gdl = gdl_loss(pNorm=2).cuda()

    EMS = ut.eval_metric_storage()
    list_selected_EMS = []
    list_ES = []
    for i_tmp in range(len(st.list_standard_eval_dir)):
        list_selected_EMS.append(ut.eval_selected_metirc_storage())
        list_ES.append(
            ut.EarlyStopping(delta=0,
                             patience=st.early_stopping_patience,
                             verbose=True))

    loss_tmp = [0] * 5
    loss_tmp_total = 0
    print('training')
    # dummy optimizer step so the first scheduler.step() call follows an optimizer.step()
    optimizer.zero_grad()
    optimizer.step()
    """ epoch """
    num_data = len(train_loader.dataset)
    for epoch in range(1, config.num_epochs + 1):
        scheduler.step()
        print(" ")
        print("---------------  epoch {} ----------------".format(epoch))
        """ print learning rate """
        for param_group in optimizer.param_groups:
            print('current LR : {}'.format(param_group['lr']))
        """ batch """
        for i, data_batch in enumerate(train_loader):
            # start = time.time()
            model.train()
            with torch.no_grad():
                """ input"""
                datas = Variable(data_batch['data'].float()).cuda()
                # labels = Variable(data_batch['label'].long()).cuda()
                labels = Variable(data_batch['label'].float()).cuda()
                """ data augmentation """
                ##TODO : flip
                # flip_flag_list = np.random.normal(size=datas.shape[0])>0
                # datas[flip_flag_list] = datas[flip_flag_list].flip(-3)

                ##TODO : translation, cropping
                dict_result = ut.data_augmentation(datas=datas,
                                                   cur_epoch=epoch)
                datas = dict_result['datas']
                translation_list = dict_result['translation_list']
                # aug_dict_result = ut.data_augmentation(datas=aug_datas, cur_epoch=epoch)
                # aug_datas = aug_dict_result['datas']
                """ minmax norm"""
                if st.list_data_norm_type[st.data_norm_type_num] == 'minmax':
                    tmp_datas = datas.view(datas.size(0), -1)
                    tmp_datas -= tmp_datas.min(1, keepdim=True)[0]
                    tmp_datas /= tmp_datas.max(1, keepdim=True)[0]
                    datas = tmp_datas.view_as(datas)
                """ gaussain noise """
                # Gaussian_dist = torch.distributions.normal.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([0.01]))
                # Gaussian_dist = torch.distributions.normal.Normal(loc=torch.tensor([0.0]), scale=torch.FloatTensor(1).uniform_(0, 0.01))
                # Gaussian_noise = Gaussian_dist.sample(datas.size()).squeeze(-1)
                # datas = datas + Gaussian_noise.cuda()
            """ forward propagation """
            dict_result = model(datas, translation_list)
            output_1 = dict_result['logits']
            output_2 = dict_result['Aux_logits']
            output_3 = dict_result['logitMap']
            output_4 = dict_result['l1_norm']

            #
            loss_list_1 = []
            count_loss = 0
            if fst.flag_loss_1 == True:
                s_labels = ut.smooth_one_hot(labels,
                                             config.num_classes,
                                             smoothing=st.smoothing_img)
                loss_2 = criterion_cls(
                    output_1,
                    s_labels) * st.lambda_major[0] / st.iter_to_update
                loss_list_1.append(loss_2)
                loss_tmp[count_loss] += loss_2.data.cpu().numpy()
                if (EMS.total_train_iter + 1) % st.iter_to_update == 0:
                    EMS.train_aux_loss[count_loss].append(loss_tmp[count_loss])
                    loss_tmp[count_loss] = 0
                count_loss += 1

            if fst.flag_loss_2 == True:
                for i_tmp in range(len(output_2)):
                    s_labels = ut.smooth_one_hot(labels,
                                                 config.num_classes,
                                                 smoothing=st.smoothing_roi)
                    loss_2 = criterion_cls(
                        output_2[i_tmp],
                        s_labels) * st.lambda_aux[i_tmp] / st.iter_to_update
                    loss_list_1.append(loss_2)

                    loss_tmp[count_loss] += loss_2.data.cpu().numpy()
                    if (EMS.total_train_iter + 1) % st.iter_to_update == 0:
                        EMS.train_aux_loss[count_loss].append(
                            loss_tmp[count_loss])
                        loss_tmp[count_loss] = 0
                    count_loss += 1

            if fst.flag_loss_3 == True:
                # patch

                list_loss_tmp = []
                for tmp_j in range(len(output_4)):  # type i.e., patch, roi
                    loss_2 = 0
                    for tmp_i in range(len(output_4[tmp_j])):  # batch
                        tmp_shape = output_4[tmp_j][tmp_i].shape
                        logits = output_4[tmp_j][tmp_i].view(
                            tmp_shape[0], tmp_shape[1], -1)
                        # loss_2 += torch.norm(logits, p=1)
                        loss_2 += torch.norm(logits,
                                             p=1) / (logits.view(-1).size(0))
                    list_loss_tmp.append(
                        (loss_2 / len(output_4[tmp_j]) * st.l1_reg_norm) /
                        st.iter_to_update)
                loss_list_1.append(sum(list_loss_tmp))

                loss_tmp[count_loss] += sum(list_loss_tmp).data.cpu().numpy()
                if (EMS.total_train_iter + 1) % st.iter_to_update == 0:
                    EMS.train_aux_loss[count_loss].append(loss_tmp[count_loss])
                    loss_tmp[count_loss] = 0
                count_loss += 1
            """ L1 reg"""
            # norm = torch.FloatTensor([0]).cuda()
            # for parameter in model.parameters():
            #     norm += torch.norm(parameter, p=1)
            # loss_list_1.append(norm * st.l1_reg)

            loss = sum(loss_list_1)
            loss.backward()
            torch.cuda.empty_cache()
            loss_tmp_total += loss.data.cpu().numpy()

            #TODO :  optimize the model param
            if (EMS.total_train_iter + 1) % st.iter_to_update == 0:
                optimizer.step()
                optimizer.zero_grad()
                """ pyplot """
                EMS.total_train_step += 1
                EMS.train_step.append(EMS.total_train_step)
                EMS.train_loss.append(loss_tmp_total)
                """ print the train loss and tensorboard"""
                if (EMS.total_train_step) % 10 == 0:
                    # print('time : ', time.time() - start)
                    print('Epoch [%d/%d], Step [%d/%d],  Loss: %.4f' %
                          (epoch, config.num_epochs, (i + 1),
                           (num_data // (config.batch_size)), loss_tmp_total))
                loss_tmp_total = 0

            EMS.total_train_iter += 1
            # scheduler.step(epoch + i / len(train_loader))
        """ val """
        if Validation == True:
            print("------------------  val  --------------------------")
            if fst.flag_cropping == True and fst.flag_eval_cropping == True:
                dict_result = ut.eval_classification_model_cropped_input(
                    config, fold, val_loader, model, criterion_cls)
            elif fst.flag_translation == True and fst.flag_eval_translation == True:
                dict_result = ut.eval_classification_model_esemble(
                    config, fold, val_loader, model, criterion_cls)
            elif fst.flag_MC_dropout == True:
                dict_result = ut.eval_classification_model_MC_dropout(
                    config, fold, val_loader, model, criterion_cls)
            else:
                dict_result = ut.eval_classification_model(
                    config, fold, val_loader, model, criterion_cls)
            val_loss = dict_result['Loss']
            acc = dict_result['Acc']
            auc = dict_result['AUC']

            print('Fold : %d, Epoch [%d/%d] val Loss = %f val Acc = %f' %
                  (fold, epoch, config.num_epochs, val_loss, acc))
            """ save the metric """
            EMS.dict_val_metric['val_loss'].append(val_loss)
            EMS.dict_val_metric['val_acc'].append(acc)
            if fst.flag_loss_2 == True:
                for tmp_i in range(len(st.lambda_aux)):
                    EMS.dict_val_metric['val_acc_aux'][tmp_i].append(
                        dict_result['Acc_aux'][tmp_i])
            EMS.dict_val_metric['val_auc'].append(auc)
            EMS.val_step.append(EMS.total_train_step)

            n_stacking_loss_for_selection = 5
            if len(EMS.dict_val_metric['val_loss_queue']
                   ) > n_stacking_loss_for_selection:
                EMS.dict_val_metric['val_loss_queue'].popleft()
            EMS.dict_val_metric['val_loss_queue'].append(val_loss)
            EMS.dict_val_metric['val_mean_loss'].append(
                np.mean(EMS.dict_val_metric['val_loss_queue']))
            """ save model """
            for i_tmp in range(len(list_selected_EMS)):
                save_flag = ut.model_save_through_validation(
                    fold,
                    epoch,
                    EMS=EMS,
                    selected_EMS=list_selected_EMS[i_tmp],
                    ES=list_ES[i_tmp],
                    model=model,
                    dir_save_model=list_dir_save_model[i_tmp],
                    metric_1=st.list_standard_eval[i_tmp],
                    metric_2='',
                    save_flag=False)

        if Test_flag == True:
            print(
                "------------------  test _ test dataset  --------------------------"
            )
            """ load data """
            if fst.flag_cropping == True and fst.flag_eval_cropping == True:
                print("eval : cropping")
                dict_result = ut.eval_classification_model_cropped_input(
                    config, fold, test_loader, model, criterion_cls)
            elif fst.flag_translation == True and fst.flag_eval_translation == True:
                print("eval : assemble")
                dict_result = ut.eval_classification_model_esemble(
                    config, fold, test_loader, model, criterion_cls)
            elif fst.flag_MC_dropout == True:
                dict_result = ut.eval_classification_model_MC_dropout(
                    config, fold, test_loader, model, criterion_cls)
            else:
                print("eval : whole image")
                dict_result = ut.eval_classification_model(
                    config, fold, test_loader, model, criterion_cls)
            acc = dict_result['Acc']
            test_loss = dict_result['Loss']
            """ pyplot """
            EMS.test_acc.append(acc)
            if fst.flag_loss_2 == True:
                for tmp_i in range(len(st.lambda_aux)):
                    EMS.test_acc_aux[tmp_i].append(
                        dict_result['Acc_aux'][tmp_i])
            EMS.test_loss.append(test_loss)
            EMS.test_step.append(EMS.total_train_step)

            print('number of test samples : {}'.format(len(
                test_loader.dataset)))
            print('Fold : %d, Epoch [%d/%d] test Loss = %f test Acc = %f' %
                  (fold, epoch, config.num_epochs, test_loss, acc))
        """ learning rate decay"""
        EMS.LR.append(optimizer.param_groups[0]['lr'])
        # scheduler.step()
        # scheduler.step(val_loss)
        """ plot the chat """
        if epoch % 1 == 0:
            ut.plot_training_info_1(fold,
                                    dir_pyplot,
                                    EMS,
                                    flag='percentile',
                                    flag_match=False)

        ##TODO : early stop only if all of metric has been stopped
        tmp_count = 0
        for i in range(len(list_ES)):
            if list_ES[i].early_stop == True:
                tmp_count += 1
        if tmp_count == len(list_ES):
            break
    """ release the model """
    del model, EMS
    torch.cuda.empty_cache()
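
The loop above divides every loss term by st.iter_to_update and only calls optimizer.step() every st.iter_to_update iterations, which is plain gradient accumulation. A minimal, self-contained sketch of that pattern (toy model and data, not the repository's classes):

import torch
import torch.nn as nn

model = nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.BCEWithLogitsLoss()
iter_to_update = 4  # accumulate gradients over 4 mini-batches

optimizer.zero_grad()
for it in range(20):
    x = torch.randn(8, 16)
    y = torch.randint(0, 2, (8, 1)).float()
    # scale the loss so the accumulated gradient matches one large batch
    loss = criterion(model(x), y) / iter_to_update
    loss.backward()
    if (it + 1) % iter_to_update == 0:
        optimizer.step()
        optimizer.zero_grad()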
Example #4
def main(args):
    # Setup datasets
    dload_train, dload_train_labeled, dload_valid, dload_test = get_data(args)

    # Model and buffer
    sample_q = get_sample_q(args)
    f, replay_buffer = get_model_and_buffer(args, sample_q)

    # Setup Optimizer
    params = f.class_output.parameters() if args.clf_only else f.parameters()
    if args.optimizer == "adam":
        optim = torch.optim.Adam(params,
                                 lr=args.lr,
                                 betas=[0.9, 0.999],
                                 weight_decay=args.weight_decay)
    else:
        optim = torch.optim.SGD(params,
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay)

    best_valid_acc = 0.0
    cur_iter = 0
    for epoch in range(args.start_epoch, args.n_epochs):

        # Decay lr
        if epoch in args.decay_epochs:
            for param_group in optim.param_groups:
                new_lr = param_group["lr"] * args.decay_rate
                param_group["lr"] = new_lr

        # Load data
        for i, (x_p_d, _) in tqdm(enumerate(dload_train)):
            # Warmup
            if cur_iter <= args.warmup_iters:
                lr = args.lr * cur_iter / float(args.warmup_iters)
                for param_group in optim.param_groups:
                    param_group["lr"] = lr

            x_p_d = x_p_d.to(device)
            x_lab, y_lab = dload_train_labeled.__next__()
            x_lab, y_lab = x_lab.to(device), y_lab.to(device)

            # Label smoothing
            dist = smooth_one_hot(y_lab, args.n_classes, args.smoothing)

            L = 0.0

            # log p(y|x) cross entropy loss
            if args.pyxce > 0:
                logits = f.classify(x_lab)
                l_pyxce = KHotCrossEntropyLoss()(logits, dist)
                if cur_iter % args.print_every == 0:
                    acc = (logits.max(1)[1] == y_lab).float().mean()
                    print("p(y|x)CE {}:{:>d} loss={:>14.9f}, acc={:>14.9f}".
                          format(epoch, cur_iter, l_pyxce.item(), acc.item()))
                    logger.record_dict({
                        "l_pyxce": l_pyxce.cpu().data.item(),
                        "acc_pyxce": acc.item()
                    })
                L += args.pyxce * l_pyxce

            # log p(x) using sgld
            if args.pxsgld > 0:
                if args.class_cond_p_x_sample:
                    assert not args.uncond, "can only draw class-conditional samples if EBM is class-cond"
                    y_q = torch.randint(0, args.n_classes,
                                        (args.sgld_batch_size, )).to(device)
                    x_q = sample_q(f, replay_buffer, y=y_q)
                else:
                    x_q = sample_q(f, replay_buffer)  # sample from log-sumexp
                fp_all = f(x_p_d)
                fq_all = f(x_q)
                fp = fp_all.mean()
                fq = fq_all.mean()
                l_pxsgld = -(fp - fq)
                if cur_iter % args.print_every == 0:
                    print(
                        "p(x)SGLD | {}:{:>d} loss={:>14.9f} f(x_p_d)={:>14.9f} f(x_q)={:>14.9f}"
                        .format(epoch, i, l_pxsgld, fp, fq))
                    logger.record_dict(
                        {"l_pxsgld": l_pxsgld.cpu().data.item()})
                L += args.pxsgld * l_pxsgld

            # log p(x) using contrastive learning
            if args.pxcontrast > 0:
                # ones like dist to use all indexes
                ones_dist = torch.ones_like(dist).to(device)
                output, target, ce_output, neg_num = f.joint(img=x_lab,
                                                             dist=ones_dist)
                l_pxcontrast = nn.CrossEntropyLoss(reduction="mean")(output,
                                                                     target)
                if cur_iter % args.print_every == 0:
                    acc = (ce_output.max(1)[1] == y_lab).float().mean()
                    print(
                        "p(x)Contrast {}:{:>d} loss={:>14.9f}, acc={:>14.9f}".
                        format(epoch, cur_iter, l_pxcontrast.item(),
                               acc.item()))
                    logger.record_dict({
                        "l_pxcontrast":
                        l_pxcontrast.cpu().data.item(),
                        "acc_pxcontrast":
                        acc.item()
                    })
                L += args.pxcontrast * l_pxcontrast

            # log p(x|y) using sgld
            if args.pxysgld > 0:
                x_q_lab = sample_q(f, replay_buffer, y=y_lab)
                fp, fq = f(x_lab).mean(), f(x_q_lab).mean()
                l_pxysgld = -(fp - fq)
                if cur_iter % args.print_every == 0:
                    print(
                        "p(x|y)SGLD | {}:{:>d} loss={:>14.9f} f(x_p_d)={:>14.9f} f(x_q)={:>14.9f}"
                        .format(epoch, i, l_pxysgld.item(), fp, fq))
                    logger.record_dict(
                        {"l_pxysgld": l_pxysgld.cpu().data.item()})
                L += args.pxysgld * l_pxysgld

            # log p(x|y) using contrastive learning
            if args.pxycontrast > 0:
                output, target, ce_output, neg_num = f.joint(img=x_lab,
                                                             dist=dist)
                l_pxycontrast = nn.CrossEntropyLoss(reduction="mean")(output,
                                                                      target)
                if cur_iter % args.print_every == 0:
                    acc = (ce_output.max(1)[1] == y_lab).float().mean()
                    print(
                        "p(x|y)Contrast {}:{:>d} loss={:>14.9f}, acc={:>14.9f}"
                        .format(epoch, cur_iter, l_pxycontrast.item(),
                                acc.item()))
                    logger.record_dict({
                        "l_pxycontrast":
                        l_pxycontrast.cpu().data.item(),
                        "acc_pxycontrast":
                        acc.item()
                    })
                L += args.pxycontrast * l_pxycontrast

            # SGLD training of log q(x) may diverge;
            # break here and record the information needed to restart
            if L.abs().item() > 1e8:
                # print restart info (stdout/stderr may still be redirected to a log)
                print("restart epoch: {}".format(epoch))
                print("save dir: {}".format(args.log_dir))
                print("id: {}".format(args.id))
                print("steps: {}".format(args.n_steps))
                print("seed: {}".format(args.seed))
                print("exp prefix: {}".format(args.exp_prefix))
                # restore stdout/stderr and repeat the same info on the console
                sys.stdout = sys.__stdout__
                sys.stderr = sys.__stderr__
                print("restart epoch: {}".format(epoch))
                print("save dir: {}".format(args.log_dir))
                print("id: {}".format(args.id))
                print("steps: {}".format(args.n_steps))
                print("seed: {}".format(args.seed))
                print("exp prefix: {}".format(args.exp_prefix))
                assert False, "loss exploded, aborting training"

            optim.zero_grad()
            L.backward()
            optim.step()
            cur_iter += 1

        if epoch % args.plot_every == 0:
            if args.plot_uncond:
                if args.class_cond_p_x_sample:
                    assert not args.uncond, "can only draw class-conditional samples if EBM is class-cond"
                    y_q = torch.randint(0, args.n_classes,
                                        (args.sgld_batch_size, )).to(device)
                    x_q = sample_q(f, replay_buffer, y=y_q)
                    plot(
                        "{}/x_q_{}_{:>06d}.png".format(args.log_dir, epoch, i),
                        x_q)
                    if args.plot_contrast:
                        x_q = sample_q(f, replay_buffer, y=y_q, contrast=True)
                        plot(
                            "{}/contrast_x_q_{}_{:>06d}.png".format(
                                args.log_dir, epoch, i), x_q)
                else:
                    x_q = sample_q(f, replay_buffer)
                    plot(
                        "{}/x_q_{}_{:>06d}.png".format(args.log_dir, epoch, i),
                        x_q)
                    if args.plot_contrast:
                        x_q = sample_q(f, replay_buffer, contrast=True)
                        plot(
                            "{}/contrast_x_q_{}_{:>06d}.png".format(
                                args.log_dir, epoch, i), x_q)
            if args.plot_cond:  # generate class-conditional samples
                y = torch.arange(0, args.n_classes)[None].repeat(
                    args.n_classes,
                    1).transpose(1, 0).contiguous().view(-1).to(device)
                x_q_y = sample_q(f, replay_buffer, y=y)
                plot("{}/x_q_y{}_{:>06d}.png".format(args.log_dir, epoch, i),
                     x_q_y)
                if args.plot_contrast:
                    y = torch.arange(0, args.n_classes)[None].repeat(
                        args.n_classes,
                        1).transpose(1, 0).contiguous().view(-1).to(device)
                    x_q_y = sample_q(f, replay_buffer, y=y, contrast=True)
                    plot(
                        "{}/contrast_x_q_y_{}_{:>06d}.png".format(
                            args.log_dir, epoch, i), x_q_y)

        if args.ckpt_every > 0 and epoch % args.ckpt_every == 0:
            checkpoint(f, replay_buffer, f"ckpt_{epoch}.pt", args)

        if epoch % args.eval_every == 0:
            # Validation set
            valid_correct, val_loss = eval_classification(f, dload_valid)
            if valid_correct > best_valid_acc:
                best_valid_acc = valid_correct
                print("Best Valid!: {}".format(valid_correct))
                checkpoint(f, replay_buffer, "best_valid_ckpt.pt", args)
            # Test set
            test_correct, test_loss = eval_classification(f, dload_test)
            print("Epoch {}: Valid Loss {}, Valid Acc {}".format(
                epoch, val_loss, valid_correct))
            print("Epoch {}: Test Loss {}, Test Acc {}".format(
                epoch, test_loss, test_correct))
            f.train()
            logger.record_dict({
                "Epoch": epoch,
                "Valid Loss": val_loss,
                "Valid Acc": valid_correct.detach().cpu().numpy(),
                "Test Loss": test_loss,
                "Test Acc": test_correct.detach().cpu().numpy(),
                "Best Valid": best_valid_acc.detach().cpu().numpy(),
                "Loss": L.cpu().data.item(),
            })
        checkpoint(f, replay_buffer, "last_ckpt.pt", args)

        logger.dump_tabular()
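
The line x_lab, y_lab = dload_train_labeled.__next__() assumes that dload_train_labeled is an iterator that never exhausts. One common way to build such a cycling iterator (an assumption about get_data, which is not shown here) is:

import torch
from torch.utils.data import DataLoader, TensorDataset

def cycle(loader):
    # yield batches forever by restarting the DataLoader when it ends
    while True:
        for batch in loader:
            yield batch

# toy labeled dataset standing in for the real one
ds = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))
dload_train_labeled = cycle(DataLoader(ds, batch_size=16, shuffle=True))

x_lab, y_lab = next(dload_train_labeled)  # equivalent to .__next__()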