Example #1
0
def _forward(model, x, domain_idx, dsbn, gpu, with_ft):
    """Run ``model`` on the batch ``x``.

    DSBN models (``dsbn=True``) receive a second positional argument: a long tensor filled with
    ``domain_idx``, one entry per sample, placed on GPU ``gpu``. Other models receive only ``x``.
    ``with_ft`` is forwarded unchanged; this script expects ``(pred, feature)`` back when it is
    True and ``pred`` alone when it is False.
    """
    if dsbn:
        domain_labels = domain_idx * torch.ones(x.shape[0], dtype=torch.long).cuda(gpu)
        return model(x, domain_labels, with_ft=with_ft)
    return model(x, with_ft=with_ft)


def _next_batch(dataloader_iters, dataloaders, idx):
    """Return the next batch of ``dataloaders[idx]``, restarting its iterator when an epoch ends.

    ``dataloader_iters[idx]`` is replaced in place on restart, so callers can keep drawing
    batches for any number of iterations regardless of the loader's length.
    """
    try:
        return next(dataloader_iters[idx])
    except StopIteration:
        dataloader_iters[idx] = iter(dataloaders[idx])
        return next(dataloader_iters[idx])


def _evaluate_targets(args, model, target_val_datasets, target_val_dataloaders, num_source_domains,
                      logger, writer, i_iter):
    """Evaluate ``model`` on every target validation set and log overall and per-class accuracies.

    Target ``k`` is evaluated with domain index ``num_source_domains + k`` (targets follow the
    sources). The model is switched to eval mode for the pass and back to train mode afterwards.

    Returns:
        ``(total_val_accuracies, mean_val_accuracies, val_accuracies_each_c)`` with one entry per
        target: the overall accuracy, the mean of the class-wise accuracies, and a list of
        ``(class_name, accuracy)`` pairs. Values are whatever ``eval_utils`` returns; the logs
        multiply them by 100, so they are presumably fractions in [0, 1].
    """
    total_val_accuracies = []
    mean_val_accuracies = []
    val_accuracies_each_c = []
    model.eval()  # evaluation mode
    for trg_idx, target_val_dataloader in enumerate(target_val_dataloaders, num_source_domains):
        target_pos = trg_idx - num_source_domains
        target_name = args.target_datasets[target_pos]
        pred_vals = []
        y_vals = []
        with torch.no_grad():
            for x_val, y_val in target_val_dataloader:
                y_vals.append(y_val.cpu())
                pred_val = _forward(model, x_val.cuda(args.gpu), trg_idx, args.dsbn, args.gpu, with_ft=False)
                pred_vals.append(pred_val.cpu())

        pred_vals = torch.cat(pred_vals, 0)
        y_vals = torch.cat(y_vals, 0)
        total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])

        val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                       class_idx=c, topk=(1,))[0]))
                               for c, c_name in enumerate(target_val_datasets[target_pos].classes)]
        logger.info('\n{} Accuracy of Each class\n'.format(target_name) +
                    ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                             for c_name, c_val_acc in val_accuracy_each_c]))
        mean_val_accuracy = float(
            torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))

        logger.info('{} mean Accuracy: {:.2f}%'.format(target_name, 100 * mean_val_accuracy))
        logger.info('{} Accuracy: {:.2f}%'.format(target_name, total_val_accuracy * 100))

        total_val_accuracies.append(total_val_accuracy)
        val_accuracies_each_c.append(val_accuracy_each_c)
        mean_val_accuracies.append(mean_val_accuracy)

        if args.use_tfboard:
            writer.add_scalar('Val_acc', total_val_accuracy, i_iter)
            for c_name, c_val_acc in val_accuracy_each_c:
                # pass the global step; without it these points are not placed on the iteration axis
                writer.add_scalar('Val_acc_of_{}'.format(c_name), c_val_acc, i_iter)
    model.train(True)  # train mode
    return total_val_accuracies, mean_val_accuracies, val_accuracies_each_c


def _save_training_state(args, i_iter, logger, model, optimizer, optimizer_D=None, discriminators=None,
                         srcs_centroids=None, trgs_centroids=None, best=False):
    """Checkpoint the model, optimizers, discriminators and centroids through ``io_utils``.

    Modules are moved to the CPU before their state dicts are taken, so the saved tensors are CPU
    tensors. Afterwards they are shipped back to ``args.gpu``. ``discriminators``,
    ``srcs_centroids`` and ``trgs_centroids`` are updated in place. The GPU model is returned so
    the caller can rebind it. When ``best`` is True and ``args.adv_loss`` is set, the source and
    target dataset names are also stored next to the discriminator state.
    """
    options = io_utils.get_model_options_from_args(args, i_iter)
    # dict to save
    model_dict = {'model': model.cpu().state_dict()}
    optimizer_dict = {'optimizer': optimizer.state_dict()}
    if args.adv_loss:
        optimizer_dict.update({'optimizer_D': optimizer_D.state_dict(),
                               'discriminators': [discriminator.cpu().state_dict()
                                                  for discriminator in discriminators]})
        if best:
            optimizer_dict.update({'source_datasets': args.source_datasets,
                                   'target_datasets': args.target_datasets})
    centroids_dict = {}
    if args.sm_loss:
        centroids_dict = {
            'srcs_centroids': [src_centroids.cpu().state_dict() for src_centroids in srcs_centroids],
            'trgs_centroids': [trg_centroids.cpu().state_dict() for trg_centroids in trgs_centroids]}
    io_utils.save_checkpoints(args.save_dir, options, i_iter, model_dict, optimizer_dict, centroids_dict,
                              logger, best=best)

    # ship to cuda
    model = model.cuda(args.gpu)
    if args.adv_loss:
        discriminators[:] = [discriminator.cuda(args.gpu) for discriminator in discriminators]
    if args.sm_loss:
        srcs_centroids[:] = [src_centroids.cuda(args.gpu) for src_centroids in srcs_centroids]
        trgs_centroids[:] = [trg_centroids.cuda(args.gpu) for trg_centroids in trgs_centroids]
    return model


def main():
    """Train a (multi-)source domain-adaptation model and evaluate it periodically.

    All configuration comes from ``parse_args()``. The model is trained with source
    cross-entropy plus, depending on the flags:
      * ``feature_penalty`` (IRM) terms on sources (true labels) and targets (pseudo-labels),
        weighted by ``weight_source_irm`` / ``weight_target_irm``. They can be held at zero for
        the first ``iters_active_irm`` iterations;
      * an adversarial domain loss with one discriminator per (target, source) pair
        (``adv_loss``);
      * a moving semantic loss between per-domain class centroids (``sm_loss``).
    Every ``save_interval`` iterations all targets are evaluated. The best model is checkpointed
    and summarized in ``./output/<script>_best_result.txt``.
    """
    args = parse_args()
    # models whose name contains 'dsbn' take a per-sample domain index in their forward pass
    args.dsbn = 'dsbn' in args.model_name
    args.source_dataset = '|'.join(args.source_datasets)
    args.target_dataset = '|'.join(args.target_datasets)
    torch.cuda.set_device(args.gpu)  # set current gpu device id so pin_memory works on the target gpu
    start_time = datetime.datetime.now()  # execution start time

    # make save_dir
    os.makedirs(args.save_dir, exist_ok=True)

    # create log file (append when resuming so earlier records are kept)
    log_filename = 'train_records.log'
    log_path = os.path.join(args.save_dir, log_filename)
    logger = io_utils.get_logger(__name__, log_file=log_path, write_level=logging.INFO,
                                 print_level=logging.INFO if args.print_console else None,
                                 mode='a' if args.resume else 'w')

    # set num_classes by checking exp_setting
    if args.num_classes == 0:
        if args.exp_setting == 'digits':
            logger.warning('num_classes are not 10! set to 10.')
            args.num_classes = 10
        elif args.exp_setting == 'office':
            logger.warning('num_classes are not 31! set to 31.')
            args.num_classes = 31
        elif args.exp_setting in ['visda', 'imageclef']:
            logger.warning('num_classes are not 12! set to 12.')
            args.num_classes = 12
        elif args.exp_setting in ['office-home']:
            logger.warning('num_classes are not 65! set to 65.')
            args.num_classes = 65
        elif args.exp_setting in ['office-caltech']:
            logger.warning('num_classes are not 10! set to 10.')
            args.num_classes = 10
        else:
            raise AttributeError('Wrong num_classes: {}'.format(args.num_classes))

    if args.weight_irm > 0.0:
        # a single shared IRM weight overrides the per-domain weights
        args.weight_source_irm = args.weight_irm
        args.weight_target_irm = args.weight_irm

    weight_irm_backup = None
    if args.iters_active_irm > 0:
        # IRM penalties stay off until iteration ``iters_active_irm``; remember the configured weights
        weight_irm_backup = (args.weight_source_irm, args.weight_target_irm)
        args.weight_source_irm = 0
        args.weight_target_irm = 0

    if args.manual_seed:
        # set manual seed
        args.manual_seed = np.uint32(args.manual_seed)
        seed = args.manual_seed
    else:
        # draw within the uint32 range: np.uint32(value >= 2**32) raises OverflowError on NumPy >= 2
        seed = np.uint32(random.randrange(2 ** 32))
    # random.seed() rejects numpy scalars on Python >= 3.11, so every RNG is seeded with a plain int
    torch.manual_seed(int(seed))
    torch.cuda.manual_seed(int(seed))
    random.seed(int(seed))
    np.random.seed(int(seed))
    logger.info('Random Seed: {}'.format(int(seed)))
    args.random_seed = seed  # save seed into args

    if args.resume:
        logger.info('Resume training')
    else:
        logger.info('\nArguments:\n' + pprint.pformat(vars(args), indent=4))  # print args
    torch.save(vars(args), os.path.join(args.save_dir, 'args_dict.pth'))  # save args

    num_classes = args.num_classes
    in_features = args.in_features if args.in_features != 0 else num_classes
    num_domains = len(args.source_datasets) + len(args.target_datasets)
    num_source_domains = 1 if args.merge_sources else len(args.source_datasets)
    num_target_domains = len(args.target_datasets)

    # tfboard
    writer = None
    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        tfboard_dir = os.path.join(args.save_dir, 'tfboard')
        os.makedirs(tfboard_dir, exist_ok=True)
        writer = SummaryWriter(tfboard_dir)

    # resume
    checkpoints = None
    if args.resume:
        try:
            checkpoints = io_utils.load_latest_checkpoints(args.save_dir, args, logger)
        except FileNotFoundError:
            logger.warning('Latest checkpoints are not found! Trying to load best model...')
            checkpoints = io_utils.load_best_checkpoints(args.save_dir, args, logger)

        start_iter = checkpoints[0]['iteration'] + 1
    else:
        start_iter = 1

    ###################################################################################################################
    #                                               Data Loading                                                      #
    ###################################################################################################################

    source_train_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, source_name, 'train', args.jitter))
                             for source_name in args.source_datasets]
    target_train_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, target_name, 'train', args.jitter))
                             for target_name in args.target_datasets]

    if args.merge_sources:
        # concatenate all source datasets into one (trained as a single source domain)
        merged_source_train_dataset = source_train_datasets[0]
        for source_train_dataset in source_train_datasets[1:]:
            merged_source_train_dataset = merged_source_train_dataset + source_train_dataset
        source_train_datasets = [merged_source_train_dataset]

    # dataloader
    source_train_dataloaders = [data.DataLoader(source_train_dataset, batch_size=args.batch_size, shuffle=True,
                                                num_workers=args.num_workers, drop_last=True, pin_memory=True)
                                for source_train_dataset in source_train_datasets]
    target_train_dataloaders = [data.DataLoader(target_train_dataset, batch_size=args.batch_size, shuffle=True,
                                                num_workers=args.num_workers, drop_last=True, pin_memory=True)
                                for target_train_dataset in target_train_datasets]

    source_train_dataloader_iters = [iter(dataloader) for dataloader in source_train_dataloaders]
    target_train_dataloader_iters = [iter(dataloader) for dataloader in target_train_dataloaders]

    # validation dataloader
    target_val_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, target_name, 'val', args.jitter))
                           for target_name in args.target_datasets]
    target_val_dataloaders = [data.DataLoader(target_val_dataset, batch_size=args.batch_size,
                                              shuffle=False, num_workers=args.num_workers, pin_memory=True)
                              for target_val_dataset in target_val_datasets]

    ###################################################################################################################
    #                                               Model Loading                                                     #
    ###################################################################################################################
    model = get_model(args.model_name, args.num_classes, args.in_features, num_domains=num_domains, pretrained=True)

    model.train(True)
    if args.resume:
        model.load_state_dict(checkpoints[0]['model'])
    elif args.init_model_path:
        init_checkpoint = torch.load(args.init_model_path)
        model.load_state_dict(init_checkpoint['model'])
    model = model.cuda(args.gpu)

    params = get_optimizer_params(model, args.learning_rate, weight_decay=args.weight_decay,
                                  double_bias_lr=args.double_bias_lr, base_weight_factor=args.base_weight_factor)

    discriminators = None
    if args.adv_loss:
        # one discriminator per (target, source) pair, indexed by trg_idx * num_source_domains + src_idx
        discriminators = [get_discriminator(args.exp_setting, in_features=in_features)
                          for _ in range(num_target_domains) for _ in range(num_source_domains)]
        discriminators = [discriminator.cuda(args.gpu) for discriminator in discriminators]
        D_params = get_optimizer_params(discriminators, args.learning_rate, weight_decay=args.weight_decay,
                                        double_bias_lr=args.double_bias_lr, base_weight_factor=None)
        if args.resume and checkpoints[1]:
            for d_idx, discriminator in enumerate(discriminators):
                discriminator.load_state_dict(checkpoints[1]['discriminators'][d_idx])

    srcs_centroids = None
    trgs_centroids = None
    if args.sm_loss:
        srcs_centroids = [Centroids(in_features, num_classes) for _ in range(num_source_domains)]
        trgs_centroids = [Centroids(in_features, num_classes) for _ in range(num_target_domains)]

        if args.resume and checkpoints[2]:
            for src_idx, src_centroids in enumerate(srcs_centroids):
                src_centroids.load_state_dict(checkpoints[2]['srcs_centroids'][src_idx])
            for trg_idx, trg_centroids in enumerate(trgs_centroids):
                trg_centroids.load_state_dict(checkpoints[2]['trgs_centroids'][trg_idx])

        srcs_centroids = [src_centroids.cuda(args.gpu) for src_centroids in srcs_centroids]
        trgs_centroids = [trg_centroids.cuda(args.gpu) for trg_centroids in trgs_centroids]

    ###################################################################################################################
    #                                               Train Configurations                                              #
    ###################################################################################################################
    ce_loss = nn.CrossEntropyLoss()
    bce_loss = nn.BCEWithLogitsLoss()

    lr_scheduler = LRScheduler(args.learning_rate, args.warmup_learning_rate, args.warmup_step,
                               num_steps=args.max_step,
                               alpha=10, beta=0.75, double_bias_lr=args.double_bias_lr,
                               base_weight_factor=args.base_weight_factor)

    if args.optimizer.lower() == 'sgd':
        optimizer = optim.SGD(params, momentum=0.9, nesterov=True)
    else:
        optimizer = optim.Adam(params, betas=(args.beta1, args.beta2))

    if args.resume and checkpoints[1]:
        optimizer.load_state_dict(checkpoints[1]['optimizer'])

    optimizer_D = None
    if args.adv_loss:
        if args.optimizer.lower() == 'sgd':
            optimizer_D = optim.SGD(D_params, momentum=0.9, nesterov=True)
        else:
            optimizer_D = optim.Adam(D_params, betas=(args.beta1, args.beta2))

        if args.resume and checkpoints[1]:
            optimizer_D.load_state_dict(checkpoints[1]['optimizer_D'])

    # Train Starts
    logger.info('Train Starts')
    domain_loss_adjust_factor = args.domain_loss_adjust_factor

    monitor = Monitor()

    global best_accuracy
    global best_accuracies_each_c
    global best_mean_val_accuracies
    global best_total_val_accuracies
    best_accuracy = 0.0
    best_accuracies_each_c = []
    best_mean_val_accuracies = []
    best_total_val_accuracies = []

    for i_iter in range(start_iter, args.early_stop_step + 1):

        if weight_irm_backup is not None and i_iter > args.iters_active_irm:
            # IRM warm-up is over: restore the configured penalty weights
            args.weight_source_irm, args.weight_target_irm = weight_irm_backup

        # draw one batch per domain (restarting exhausted loaders) and ship it to cuda
        src_inputs = []
        for src_dataloader_idx in range(len(source_train_dataloaders)):
            x_s, y_s = _next_batch(source_train_dataloader_iters, source_train_dataloaders, src_dataloader_idx)
            src_inputs.append((x_s.cuda(args.gpu), y_s.cuda(args.gpu)))
        trg_inputs = []
        for trg_dataloader_idx in range(len(target_train_dataloaders)):
            # target labels are never used for training
            x_t, _ = _next_batch(target_train_dataloader_iters, target_train_dataloaders, trg_dataloader_idx)
            trg_inputs.append(x_t.cuda(args.gpu))

        current_lr = lr_scheduler.current_lr(i_iter)
        adaptation_lambda = adaptation_factor((i_iter - args.warmup_step) / float(args.max_step),
                                              gamma=args.adaptation_gamma)
        # init optimizer
        optimizer.zero_grad()
        lr_scheduler(optimizer, i_iter)
        if args.adv_loss:
            optimizer_D.zero_grad()
            lr_scheduler(optimizer_D, i_iter)

        ########################################################################################################
        #                                               Train G                                                #
        ########################################################################################################
        if args.adv_loss:
            # the generator loss back-propagates through the discriminators without updating them
            for discriminator in discriminators:
                for param in discriminator.parameters():
                    param.requires_grad = False

        # sources use domain indices [0, num_source_domains); targets follow
        src_preds = [_forward(model, x_s, src_idx, args.dsbn, args.gpu, with_ft=True)
                     for src_idx, (x_s, _) in enumerate(src_inputs)]
        trg_preds = [_forward(model, x_t, trg_idx, args.dsbn, args.gpu, with_ft=True)
                     for trg_idx, x_t in enumerate(trg_inputs, num_source_domains)]

        Closs_src = 0
        Closs_src_irm = 0
        for (_, y_s), (pred_s, f_s) in zip(src_inputs, src_preds):
            Closs_src = Closs_src + ce_loss(pred_s, y_s) / float(num_source_domains)
            if args.weight_source_irm > 0:
                Closs_src_irm = Closs_src_irm + feature_penalty(f_s, model.fc, ce_loss, y_s)

        monitor.update({"Loss/Closs_src": float(Closs_src)})
        Floss = Closs_src

        if args.weight_source_irm > 0:
            Floss = Floss + Closs_src_irm * args.weight_source_irm
            monitor.update({"Loss/Closs_src_irm": float(Closs_src_irm)})

        if args.adv_loss:
            # adversarial loss: sources are labelled 0, targets 1; the generator maximizes the D loss
            Gloss = 0
            for trg_idx, (_, f_t) in enumerate(trg_preds):
                for src_idx, (_, f_s) in enumerate(src_preds):
                    discriminator = discriminators[trg_idx * num_source_domains + src_idx]
                    Dout_s = discriminator(f_s)
                    source_label = torch.zeros_like(Dout_s).cuda(args.gpu)
                    loss_adv_src = domain_loss_adjust_factor * bce_loss(Dout_s, source_label) / 2

                    Dout_t = discriminator(f_t)
                    target_label = torch.ones_like(Dout_t).cuda(args.gpu)
                    loss_adv_trg = domain_loss_adjust_factor * bce_loss(Dout_t, target_label) / 2
                    Gloss = Gloss - (loss_adv_src + loss_adv_trg)
            Gloss = Gloss / float(num_target_domains * num_source_domains)
            monitor.update({'Loss/Gloss': float(Gloss)})

            Floss = Floss + adaptation_lambda * Gloss

        # pseudo label generation (eval-mode predictions without gradient). They are only consumed
        # by the target IRM penalty and the semantic loss, so the extra forward pass is skipped otherwise.
        pred_t_pseudos = []
        if args.weight_target_irm > 0 or args.sm_loss:
            with torch.no_grad():
                model.eval()
                pred_t_pseudos = [_forward(model, x_t, trg_idx, args.dsbn, args.gpu, with_ft=False)
                                  for trg_idx, x_t in enumerate(trg_inputs, num_source_domains)]
                model.train(True)

        if args.weight_target_irm > 0:
            Closs_trg_irm = 0
            for pred_t_pseudo, (_, f_t) in zip(pred_t_pseudos, trg_preds):
                y_t_pseudo = torch.argmax(pred_t_pseudo, 1).detach()
                Closs_trg_irm = Closs_trg_irm + feature_penalty(f_t, model.fc, ce_loss, y_t_pseudo)

            Floss = Floss + Closs_trg_irm * args.weight_target_irm
            monitor.update({"Loss/Closs_trg_irm": float(Closs_trg_irm)})

        # moving semantic loss
        if args.sm_loss:
            current_srcs_centroids = [src_centroids(f_s, y_s) for src_centroids, (_, y_s), (_, f_s) in
                                      zip(srcs_centroids, src_inputs, src_preds)]

            current_trgs_centroids = [trg_centroids(f_t, torch.argmax(pred_t_pseudo, 1).detach()) for
                                      trg_centroids, pred_t_pseudo, (_, f_t) in
                                      zip(trgs_centroids, pred_t_pseudos, trg_preds)]

            semantic_loss = 0
            for current_trg_centroids in current_trgs_centroids:
                for current_src_centroids in current_srcs_centroids:
                    semantic_loss = semantic_loss + args.sm_etha * semantic_loss_calc(current_src_centroids,
                                                                                      current_trg_centroids)
            semantic_loss = semantic_loss / float(num_target_domains * num_source_domains)
            monitor.update({'Loss/SMloss': float(semantic_loss)})

            Floss = Floss + adaptation_lambda * semantic_loss

        # Floss backward
        Floss.backward()
        optimizer.step()

        ########################################################################################################
        #                                               Train D                                                #
        ########################################################################################################
        if args.adv_loss:
            for discriminator in discriminators:
                for param in discriminator.parameters():
                    param.requires_grad = True

            # adversarial loss on detached features so only the discriminators receive gradients
            Dloss = 0
            for trg_idx, (_, f_t) in enumerate(trg_preds):
                for src_idx, (_, f_s) in enumerate(src_preds):
                    discriminator = discriminators[trg_idx * num_source_domains + src_idx]
                    Dout_s = discriminator(f_s.detach())
                    source_label = torch.zeros_like(Dout_s).cuda(args.gpu)
                    loss_adv_src = domain_loss_adjust_factor * bce_loss(Dout_s, source_label) / 2

                    # target
                    Dout_t = discriminator(f_t.detach())
                    target_label = torch.ones_like(Dout_t).cuda(args.gpu)
                    loss_adv_trg = domain_loss_adjust_factor * bce_loss(Dout_t, target_label) / 2
                    Dloss = Dloss + loss_adv_src + loss_adv_trg
            Dloss = Dloss / float(num_target_domains * num_source_domains)
            monitor.update({'Loss/Dloss': float(Dloss)})
            Dloss = adaptation_lambda * Dloss
            Dloss.backward()
            optimizer_D.step()

        if args.sm_loss:
            # carry the freshly computed centroids over to the next iteration
            for src_centroids, current_src_centroids in zip(srcs_centroids, current_srcs_centroids):
                src_centroids.centroids.data = current_src_centroids.data
            for trg_centroids, current_trg_centroids in zip(trgs_centroids, current_trgs_centroids):
                trg_centroids.centroids.data = current_trg_centroids.data

        if i_iter % args.disp_interval == 0 and i_iter != 0:
            disp_msg = 'iter[{:8d}/{:8d}], '.format(i_iter, args.early_stop_step)
            disp_msg += str(monitor)
            if args.adv_loss or args.sm_loss:
                disp_msg += ', lambda={:.6f}'.format(adaptation_lambda)
            disp_msg += ', lr={:.6f}'.format(current_lr)
            logger.info(disp_msg)

            if args.use_tfboard:
                if args.save_model_hist:
                    for name, param in model.named_parameters():
                        writer.add_histogram(name, param.cpu().data.numpy(), i_iter, bins='auto')

                for k, v in monitor.losses.items():
                    writer.add_scalar(k, v, i_iter)
                if args.adv_loss or args.sm_loss:
                    writer.add_scalar('adaptation_lambda', adaptation_lambda, i_iter)
                writer.add_scalar('learning rate', current_lr, i_iter)
            monitor.reset()

        if i_iter % args.save_interval == 0 and i_iter != 0:
            logger.info("Elapsed Time: {}".format(datetime.datetime.now() - start_time))
            logger.info("Start Evaluation at {:d}".format(i_iter))

            total_val_accuracies, mean_val_accuracies, val_accuracies_each_c = _evaluate_targets(
                args, model, target_val_datasets, target_val_dataloaders, num_source_domains, logger, writer,
                i_iter)

            # VisDA runs select the best model by mean per-class accuracy, other settings by overall accuracy
            if args.exp_setting.lower() == 'visda':
                val_accuracy = float(torch.mean(torch.FloatTensor(mean_val_accuracies)))
            else:
                val_accuracy = float(torch.mean(torch.FloatTensor(total_val_accuracies)))

            if val_accuracy > best_accuracy:
                # save best model
                best_accuracy = val_accuracy
                best_accuracies_each_c = val_accuracies_each_c
                best_mean_val_accuracies = mean_val_accuracies
                best_total_val_accuracies = total_val_accuracies
                model = _save_training_state(args, i_iter, logger, model, optimizer, optimizer_D, discriminators,
                                             srcs_centroids, trgs_centroids, best=True)

                # save best result into textfile (accuracies are scaled to percent, as in the logs)
                contents = [' '.join(sys.argv) + '\n',
                            "best accuracy: {:.2f}%\n".format(100 * best_accuracy)]
                for target_name, best_accuracy_each_c, best_mean_val_accuracy, best_total_val_accuracy in zip(
                        args.target_datasets, best_accuracies_each_c, best_mean_val_accuracies,
                        best_total_val_accuracies):
                    contents.extend(["{}2{}\n".format(args.source_dataset, target_name),
                                     "best total acc: {:.2f}%\n".format(100 * best_total_val_accuracy),
                                     "best mean acc: {:.2f}%\n".format(100 * best_mean_val_accuracy),
                                     'Best Accs: ' + ''.join(["{:.2f}% ".format(100 * c_val_acc)
                                                              for _, c_val_acc in best_accuracy_each_c]) + '\n'])

                best_result_path = os.path.join('./output', '{}_best_result.txt'.format(
                    os.path.splitext(os.path.basename(__file__))[0]))
                # the output directory may not exist yet; open() would raise FileNotFoundError
                os.makedirs(os.path.dirname(best_result_path), exist_ok=True)
                with open(best_result_path, 'a+') as f:
                    f.writelines(contents)

            # logging best model results (the lists stay empty until some evaluation beats 0.0)
            for target_name, best_accuracy_each_c, best_total_val_accuracy, best_mean_val_accuracy in zip(
                    args.target_datasets, best_accuracies_each_c, best_total_val_accuracies,
                    best_mean_val_accuracies):
                logger.info(
                    '\nBest {} Accuracy of Each class\n'.format(target_name) +
                    ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                             for c_name, c_val_acc in best_accuracy_each_c]))
                logger.info('Best Accs: ' + ''.join(["{:.2f}% ".format(100 * c_val_acc)
                                                     for _, c_val_acc in best_accuracy_each_c]))
                logger.info('Best {} mean Accuracy: {:.2f}%'.format(target_name, 100 * best_mean_val_accuracy))
                logger.info('Best {} Accuracy: {:.2f}%'.format(target_name, 100 * best_total_val_accuracy))
            logger.info("Best model's Average Accuracy of targets: {:.2f}".format(100 * best_accuracy))

            if args.save_ckpts:
                model = _save_training_state(args, i_iter, logger, model, optimizer, optimizer_D, discriminators,
                                             srcs_centroids, trgs_centroids, best=False)

    if args.use_tfboard:
        writer.close()

    logger.info('Total Time: {}'.format((datetime.datetime.now() - start_time)))
Example #2
0
def main():
    """Evaluate a trained checkpoint on the target-domain test sets.

    Loads the checkpoint at ``args.model_path`` (its ``'iteration'`` and ``'model'``
    entries) and recovers model arguments from the checkpoint filename and, when
    present, from ``args_dict.pth`` in the same directory. It then logs overall,
    per-class and mean-per-class top-1 accuracy for every target test set.
    Optionally it writes feature and centroid embeddings to TensorBoard
    (``args.use_tfboard``) and the accuracies to ``evaluation_{step:06d}.pth``
    (``args.save_results``).

    Raises:
        IOError: if ``args.model_path`` is not an existing file.
        AttributeError: if ``args.num_classes`` is 0 and ``args.exp_setting`` is unknown.
    """
    args = parse_args()
    torch.cuda.set_device(args.gpu)  # set current gpu device id so pin_memory works on the target gpu
    if not os.path.isfile(args.model_path):
        raise IOError("ERROR model_path: {}".format(args.model_path))

    # load checkpoints
    checkpoint = torch.load(args.model_path)
    global_step = checkpoint['iteration']
    model_state_dict = checkpoint['model']

    # set logger (log file lives next to the checkpoint)
    model_dir = os.path.dirname(args.model_path)
    log_filename = 'evaluation_step{}.log'.format(global_step)
    log_path = os.path.join(model_dir, log_filename)
    logger = io_utils.get_logger(__name__, log_file=log_path, write_level=logging.INFO,
                                 print_level=logging.INFO if args.print_console else None)

    # set num_classes by checking exp_setting
    if args.num_classes == 0:
        if args.exp_setting == 'digits':
            logger.warning('num_classes are not 10! set to 10.')
            args.num_classes = 10
        elif args.exp_setting == 'office':
            logger.warning('num_classes are not 31! set to 31.')
            args.num_classes = 31
        elif args.exp_setting in ['visda', 'imageclef']:
            logger.warning('num_classes are not 12! set to 12.')
            args.num_classes = 12
        elif args.exp_setting in ['office-home']:
            logger.warning('num_classes are not 65! set to 65.')
            args.num_classes = 65
        elif args.exp_setting in ['office-caltech']:
            args.num_classes = 10
        else:
            raise AttributeError('Wrong num_classes: {}'.format(args.num_classes))

    # update model args from filename
    model_args = io_utils.get_model_args_dict_from_filename(os.path.basename(args.model_path))
    model_args['source_datasets'] = model_args['source_dataset'].split('|')
    model_args['target_datasets'] = model_args['target_dataset'].split('|')
    args.__dict__.update(model_args)
    # args_dict.pth (if saved at training time) takes precedence over the filename
    args_path = os.path.join(model_dir, 'args_dict.pth')
    if os.path.isfile(args_path):
        logger.info('Argument file exists. load arguments from {}'.format(args_path))
        args_dict = torch.load(args_path)
        update_dict = {'args_path': args_path,
                       'source_dataset': args_dict['source_dataset'],
                       'source_datasets': args_dict['source_datasets'],
                       'target_dataset': args_dict['target_dataset'],
                       'target_datasets': args_dict['target_datasets'],
                       'model_name': args_dict['model_name'],
                       'in_features': args_dict['in_features'], }
        args.__dict__.update(update_dict)
    args.dsbn = 'dsbn' in args.model_name  # set dsbn
    logger.info('\nArguments:\n' + pprint.pformat(vars(args), indent=4))

    model_options = io_utils.get_model_options_from_args(args, global_step)

    num_classes = args.num_classes
    num_target_domains = len(args.target_datasets)

    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        tfboard_dir = os.path.join(model_dir, 'tfboard')
        if not os.path.isdir(tfboard_dir):
            os.makedirs(tfboard_dir)
        writer = SummaryWriter(tfboard_dir)
    ###################################################################################################################
    #                                               Data Loading                                                      #
    ###################################################################################################################

    source_test_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, source_dataset, 'test', args.jitter))
                            for source_dataset in args.source_datasets]
    target_test_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, target_dataset, 'test', args.jitter))
                            for target_dataset in args.target_datasets]

    ###################################################################################################################
    #                                               Model Loading                                                     #
    ###################################################################################################################
    model = get_model(args.model_name, args.num_classes, args.in_features, pretrained=False)

    logger.info('Load trained parameters...')
    model.load_state_dict(model_state_dict)
    model.eval()
    model = model.cuda(args.gpu)

    # tfboard: write centroids
    if args.use_tfboard:
        centroids_filename = io_utils.get_centroids_filename(model_options)
        centroids_path = os.path.join(model_dir, centroids_filename)
        if os.path.isfile(centroids_path):
            logger.info('write centroids on tfboard: {}'.format(centroids_path))
            centroids_ckpt = torch.load(centroids_path)

            for s_idx, centroids in enumerate(centroids_ckpt['srcs_centroids']):
                src_centroids = centroids['centroids'].cpu().data.numpy()
                writer.add_embedding(src_centroids, metadata=list(range(num_classes)),
                                     tag='src_centroids_{}'.format(args.source_datasets[s_idx]),
                                     global_step=global_step)

            # The training script builds a *list* of per-target centroid state dicts under
            # 'trgs_centroids' (mirroring 'srcs_centroids'); presumably that is what is stored
            # here. Fall back to a single legacy 'trg_centroids' entry otherwise.
            if 'trgs_centroids' in centroids_ckpt:
                for t_idx, centroids in enumerate(centroids_ckpt['trgs_centroids']):
                    trg_centroids = centroids['centroids'].cpu().data.numpy()
                    writer.add_embedding(trg_centroids, metadata=list(range(num_classes)),
                                         tag='trg_centroids_{}'.format(args.target_datasets[t_idx]),
                                         global_step=global_step)
            else:
                trg_centroids = centroids_ckpt['trg_centroids']['centroids'].cpu().data.numpy()
                writer.add_embedding(trg_centroids, metadata=list(range(num_classes)),
                                     tag='trg_centroids', global_step=global_step)

    logger.info('Start Evaluation')
    results = {'step': global_step}
    total_features = []
    total_labels = []

    # Dataset index layout: targets first, then sources (d_idx >= num_target_domains), so
    # `target_test_datasets + source_test_datasets` can be evaluated by this same loop.
    for d_idx, dataset in enumerate(target_test_datasets):
        if d_idx < num_target_domains:
            domain_name = args.target_datasets[d_idx]
        else:
            domain_name = args.source_datasets[d_idx - num_target_domains]

        dataloader = data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                                     num_workers=args.num_workers, drop_last=False, pin_memory=True)
        pred_vals = []
        y_vals = []
        features = []  # only filled when args.use_tfboard

        with torch.no_grad():
            for x_val, y_val in dataloader:
                x_val = x_val.cuda(args.gpu)
                y_val = y_val.cuda(args.gpu)

                if args.dsbn:
                    # DSBN models take a per-sample domain index; 0 is used for every evaluated set
                    pred_val, f_val = model(x_val, torch.zeros_like(y_val), with_ft=True)
                else:
                    pred_val, f_val = model(x_val, with_ft=True)

                pred_vals.append(pred_val.cpu())
                y_vals.append(y_val.cpu())
                if args.use_tfboard:
                    features.append(f_val.cpu().data.numpy())

        pred_vals = torch.cat(pred_vals, 0)
        y_vals = torch.cat(y_vals, 0)
        test_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
        val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                       class_idx=c, topk=(1,))[0]))
                               for c, c_name in enumerate(dataset.classes)]

        # logging (same percentage format for target and source domains)
        logger.info('{} Test Accuracy: {:.4f}%'.format(domain_name, 100 * test_accuracy))
        logger.info('\nEach class Accuracy of {}\n'.format(domain_name) +
                    ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                             for c_name, c_val_acc in val_accuracy_each_c]))
        logger.info('Evaluation mean Accuracy: {:.2f}%'.format(
            100 * float(torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))))
        if args.save_results:
            results[domain_name] = test_accuracy
            results.update({domain_name + '_' + c_name: c_val_acc for c_name, c_val_acc in val_accuracy_each_c})

        if args.use_tfboard:
            embed_features = np.concatenate(features, axis=0)
            # np.int was removed in NumPy 1.24; use an explicit fixed-width integer type
            y_vals_numpy = y_vals.numpy().astype(np.int64)
            total_features.append(embed_features)
            # label = first letter of the domain name + class index, e.g. 'm3'
            total_labels += [domain_name[0] + str(int(label)) for label in y_vals_numpy]
            writer.add_embedding(embed_features, metadata=y_vals_numpy, tag=domain_name,
                                 global_step=global_step)

    if args.use_tfboard:
        total_features = np.concatenate(total_features, axis=0)
        writer.add_embedding(total_features, metadata=list(total_labels),
                             tag='feat_embed_S:{}_T:{}'.format(args.source_dataset, args.target_dataset),
                             global_step=global_step)

    # save results
    if args.save_results:
        result_filename = 'evaluation_{:06d}.pth'.format(global_step)
        torch.save(results, os.path.join(model_dir, result_filename))

    if args.use_tfboard:
        writer.close()
Example #3
0
def _predict_with_domain(model, dataloader, domain_idx, gpu):
    """Run ``model`` over every batch of ``dataloader`` with a fixed domain index.

    The model is put in eval mode for the pass (so BatchNorm uses and does not update
    its running statistics). Its previous train/eval mode is restored afterwards.

    Args:
        model: network called as ``model(x, domain_labels, with_ft=False)``.
        dataloader: iterable yielding ``(x, y)`` batches.
        domain_idx: integer domain index given to every sample.
        gpu: CUDA device id the inputs are moved to.

    Returns:
        ``(pred_vals, y_vals)``: model outputs and labels concatenated along dim 0, on CPU.
    """
    was_training = model.training
    model.eval()
    pred_vals = []
    y_vals = []
    try:
        with torch.no_grad():
            for x_val, y_val in dataloader:
                y_vals.append(y_val.cpu())
                x_val = x_val.cuda(gpu)
                y_val = y_val.cuda(gpu)
                domain_labels = domain_idx * torch.ones_like(y_val)
                pred_vals.append(model(x_val, domain_labels, with_ft=False).cpu())
    finally:
        model.train(was_training)
    return torch.cat(pred_vals, 0), torch.cat(y_vals, 0)


def _accuracy_of_each_class(pred_vals, y_vals, class_names):
    """Return ``[(class_name, top-1 accuracy), ...]`` in class-index order."""
    return [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals, class_idx=c, topk=(1,))[0]))
            for c, c_name in enumerate(class_names)]


def main():
    """Fine-tune a pretrained (teacher) model on SVHN and report test accuracy.

    Workflow:
      1. Parse args, seed RNGs, and set up logging / TensorBoard / resume.
      2. Build the model and load the teacher weights from ``args.teacher_model_path``.
         Freeze every parameter whose name contains neither 'fc' nor 'bns.1', and
         re-initialise ``fc1``/``fc2`` with Xavier-uniform.
      3. Train with cross-entropy using domain index 1. Every ``args.save_interval``
         iterations, evaluate on the SVHN test split and save the best checkpoint.
      4. Run a final evaluation pass with domain index 0 and print the results.

    Raises:
        AttributeError: if the teacher model file is missing, or ``args.num_classes``
            is 0 and ``args.exp_setting`` is neither 'office-home' nor 'digits'.
    """
    print('start finetune')
    args = parse_args()
    args.dsbn = 'dsbn' in args.model_name  # set dsbn
    args.cpua = 'cpua' in args.model_name

    torch.cuda.set_device(args.gpu)  # set current gpu device id so pin_memory works on the target gpu
    start_time = datetime.datetime.now()  # execution start time

    # make save_dir
    os.makedirs(args.save_dir, exist_ok=True)

    # check whether teacher model exists
    if not os.path.isfile(args.teacher_model_path):
        raise AttributeError('Missing teacher model path: {}'.format(args.teacher_model_path))

    # create log file
    log_path = os.path.join(args.save_dir, 'train_records.log')
    logger = io_utils.get_logger(__name__, log_file=log_path, write_level=logging.INFO,
                                 print_level=logging.INFO if args.print_console else None,
                                 mode='a' if args.resume else 'w')

    # set num_classes by checking exp_setting
    if args.num_classes == 0:
        if args.exp_setting in ['office-home']:
            logger.warning('num_classes are not 65! set to 65.')
            args.num_classes = 65
        elif args.exp_setting in ['digits']:
            logger.warning('num_classes are not 10! set to 10.')
            args.num_classes = 10
        else:
            raise AttributeError('Wrong num_classes: {}'.format(args.num_classes))

    if args.manual_seed:
        # set manual seed
        args.manual_seed = np.uint32(args.manual_seed)
        torch.manual_seed(args.manual_seed)
        torch.cuda.manual_seed(args.manual_seed)
        random.seed(args.manual_seed)
        np.random.seed(args.manual_seed)
        logger.info('Random Seed: {}'.format(int(args.manual_seed)))
        args.random_seed = args.manual_seed  # save seed into args
    else:
        seed = np.uint32(random.randrange(sys.maxsize))
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        random.seed(seed)
        np.random.seed(np.uint32(seed))
        logger.info('Random Seed: {}'.format(seed))
        args.random_seed = seed  # save seed into args

    if args.resume:
        logger.info('Resume training')
    else:
        logger.info('\nArguments:\n' + pprint.pformat(vars(args), indent=4))  # print args
    torch.save(vars(args), os.path.join(args.save_dir, 'args_dict.pth'))  # save args

    # tfboard
    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        tfboard_dir = os.path.join(args.save_dir, 'tfboard')
        if not os.path.isdir(tfboard_dir):
            os.makedirs(tfboard_dir)
        writer = SummaryWriter(tfboard_dir)

    # resume
    if args.resume:
        try:
            checkpoints = io_utils.load_latest_checkpoints(args.save_dir, args, logger)
        except FileNotFoundError:
            logger.warning('Latest checkpoints are not found! Trying to load best model...')
            checkpoints = io_utils.load_best_checkpoints(args.save_dir, args, logger)

        # NOTE(review): only the iteration counter is restored from the checkpoint; model and
        # optimizer weights still come from the teacher checkpoint below — confirm intended.
        start_iter = checkpoints[0]['iteration'] + 1
    else:
        start_iter = 1

    ###################################################################################################################
    #                                               Data Loading                                                      #
    ###################################################################################################################

    train_dataset = SVHN(root='/data/jihun/SVHN', transform=svhn_transform, download=True)
    val_dataset = SVHN(root='/data/jihun/SVHN', split='test', transform=svhn_transform, download=True)
    train_dataloader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                            num_workers=args.num_workers, drop_last=True, pin_memory=True)
    # Evaluation must see every test sample exactly once: no shuffling, keep the last partial batch.
    val_dataloader = util_data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                                          num_workers=args.num_workers, drop_last=False, pin_memory=True)

    train_dataloader_iter = iter(train_dataloader)

    ###################################################################################################################
    #                                               Model Loading                                                     #
    ###################################################################################################################
    model = get_model(args.model_name, 10, 0, 2, pretrained=True)

    params = get_optimizer_params(model, args.learning_rate, weight_decay=args.weight_decay,
                                  double_bias_lr=args.double_bias_lr, base_weight_factor=args.base_weight_factor)

    # teacher model
    print('------------------------model load------------------------')
    model.load_state_dict(torch.load(args.teacher_model_path)['model'])
    # Freeze every parameter except those whose name contains 'fc' or 'bns.1'.
    # named_parameters() is required: state_dict() yields *detached* tensors, so setting
    # requires_grad on them would leave the real parameters trainable.
    for name, p in model.named_parameters():
        if 'fc' not in name and 'bns.1' not in name:
            p.requires_grad = False

    torch.nn.init.xavier_uniform_(model.fc1.weight)
    torch.nn.init.xavier_uniform_(model.fc2.weight)

    model.train(True)
    model = model.cuda(args.gpu)

    ###################################################################################################################
    #                                               Train Configurations                                              #
    ###################################################################################################################
    ce_loss = nn.CrossEntropyLoss()

    optimizer = optim.Adam(params, betas=(args.beta1, args.beta2))

    # Train Starts
    logger.info('Train Starts')
    monitor = Monitor()

    global best_accuracy
    global best_accuracies_each_c
    global best_mean_val_accuracies
    global best_total_val_accuracies
    best_accuracy = 0.0
    best_accuracies_each_c = []
    best_mean_val_accuracies = []
    best_total_val_accuracies = []

    for i in range(start_iter, args.early_stop_step + 1):
        try:
            x_s, y_s = next(train_dataloader_iter)
        except StopIteration:
            # training loader exhausted: start a new pass over it
            train_dataloader_iter = iter(train_dataloader)
            x_s, y_s = next(train_dataloader_iter)
        # init optimizer
        optimizer.zero_grad()
        # ship to cuda
        x_s, y_s = x_s.cuda(args.gpu), y_s.cuda(args.gpu)
        # domain index 1 for every training sample
        domain_s = torch.ones(x_s.shape[0], dtype=torch.long).cuda(args.gpu)
        pred_s, f_s = model(x_s, domain_s, with_ft=True)
        loss = ce_loss(pred_s, y_s)
        monitor.update({"Loss/Closs_src": float(loss)})

        loss.backward()
        optimizer.step()

        if i % args.save_interval == 0 and i != 0:
            logger.info("Elapsed Time: {}".format(datetime.datetime.now() - start_time))
            logger.info("Start Evaluation at {:d}".format(i))

            pred_vals, y_vals = _predict_with_domain(model, val_dataloader, 1, args.gpu)
            total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
            val_accuracy_each_c = _accuracy_of_each_class(pred_vals, y_vals, val_dataset.classes)
            del pred_vals, y_vals

            logger.info('\n{} Accuracy of Each class\n'.format(args.finetune_dataset) +
                        ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                                 for c_name, c_val_acc in val_accuracy_each_c]))
            mean_val_accuracy = float(
                torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))
            logger.info('{} mean Accuracy: {:.2f}%'.format(
                args.finetune_dataset, 100 * mean_val_accuracy))
            logger.info(
                '{} Accuracy: {:.2f}%'.format(args.finetune_dataset,
                                              total_val_accuracy * 100))

            # One evaluation domain; kept as one-element lists for the best-result bookkeeping below.
            total_val_accuracies = [total_val_accuracy]
            val_accuracies_each_c = [val_accuracy_each_c]
            mean_val_accuracies = [mean_val_accuracy]

            if args.use_tfboard:
                writer.add_scalar('Val_acc', total_val_accuracy, i)
                for c_name, c_val_acc in val_accuracy_each_c:
                    writer.add_scalar('Val_acc_of_{}'.format(c_name), c_val_acc, i)

            val_accuracy = float(torch.mean(torch.FloatTensor(total_val_accuracies)))
            print("%d th iter accuracy: %.3f" % (i, val_accuracy))

            if val_accuracy > best_accuracy:
                # save best model
                best_accuracy = val_accuracy
                best_accuracies_each_c = val_accuracies_each_c
                best_mean_val_accuracies = mean_val_accuracies
                best_total_val_accuracies = total_val_accuracies
                options = io_utils.get_model_options_from_args(args, i)
                # dict to save
                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}

                # save best checkpoint; no centroids are trained here, so pass an empty centroids_dict
                io_utils.save_checkpoints(args.save_dir, options, i, model_dict, optimizer_dict, {},
                                          logger, best=True)
                # ship to cuda
                model = model.cuda(args.gpu)

                # save best result into textfile
                contents = [' '.join(sys.argv) + '\n',
                            "best accuracy: {:.2f}%\n".format(100 * best_accuracy)]
                best_accuracy_each_c = best_accuracies_each_c[0]
                best_mean_val_accuracy = best_mean_val_accuracies[0]
                best_total_val_accuracy = best_total_val_accuracies[0]
                contents.extend(["{}\n".format(args.finetune_dataset),
                                 "best total acc: {:.2f}%\n".format(100 * best_total_val_accuracy),
                                 "best mean acc: {:.2f}%\n".format(100 * best_mean_val_accuracy),
                                 'Best Accs: ' + ''.join(["{:.2f}% ".format(100 * c_val_acc)
                                                          for _, c_val_acc in best_accuracy_each_c]) + '\n'])

                output_dir = './output'
                os.makedirs(output_dir, exist_ok=True)
                best_result_path = os.path.join(output_dir, '{}_best_result.txt'.format(
                    os.path.splitext(os.path.basename(__file__))[0]))
                with open(best_result_path, 'a+') as f:
                    f.writelines(contents)

    # Final report on the SVHN test split with domain index 0 (runs in eval mode).
    # NOTE(review): periodic evaluation above uses domain index 1; confirm 0 is intended here.
    pred_vals, y_vals = _predict_with_domain(model, val_dataloader, 0, args.gpu)
    total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
    val_accuracy_each_c = _accuracy_of_each_class(pred_vals, y_vals, val_dataset.classes)
    for cls in val_accuracy_each_c:
        print(cls)
    print(total_val_accuracy)

    if args.use_tfboard:
        writer.close()