Example #1
def train(train_loader, model, criterion, optimizer, epoch, args, writer):
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.3f")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix=f"Epoch: [{epoch}]",
    )

    # switch to train mode
    model.train()

    batch_size = train_loader.batch_size
    num_batches = len(train_loader)
    end = time.time()
    for i, (images, target) in tqdm.tqdm(
        enumerate(train_loader), ascii=True, total=len(train_loader)
    ):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)

        ### for MNIST
        #images = images.expand()
        #import pdb
        #pdb.set_trace()
        
        target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)

        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1.item(), images.size(0))
        top5.update(acc5.item(), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            t = (num_batches * epoch + i) * batch_size
            progress.display(i)
            progress.write_to_tensorboard(writer, prefix="train", global_step=t)

    return top1.avg, top5.avg
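
These snippets rely on AverageMeter, ProgressMeter, and accuracy helpers defined elsewhere in their repositories. As a point of reference, here is a minimal sketch of AverageMeter and accuracy in the style of the official PyTorch ImageNet example, which is what the calls above appear to assume; the write_val flag used by the validate examples is a project-specific extension and is only kept in the constructor signature here, and a TensorBoard-aware ProgressMeter is sketched after Example #2.

import torch


class AverageMeter:
    """Tracks the latest value, running sum, count, and average of a metric."""

    def __init__(self, name, fmt=":f", write_val=True):
        self.name, self.fmt, self.write_val = name, fmt, write_val
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0, 0, 0, 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy (in percent) of `output` logits against `target` labels."""
    with torch.no_grad():
        maxk = max(topk)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return [correct[:k].reshape(-1).float().sum(0) * (100.0 / target.size(0))
                for k in topk]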
Example #2
def validate(val_loader, model, criterion, args, writer, epoch):
    batch_time = AverageMeter("Time", ":6.3f", write_val=False)
    losses = AverageMeter("Loss", ":.3f", write_val=False)
    top1 = AverageMeter("Acc@1", ":6.2f", write_val=False)
    top5 = AverageMeter("Acc@5", ":6.2f", write_val=False)
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5],
                             prefix="Test: ")

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in tqdm.tqdm(enumerate(val_loader),
                                             ascii=True,
                                             total=len(val_loader)):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)

            target = target.cuda(args.gpu, non_blocking=True)

            # YHT modification
            '''
            This will severely influence the generalization! drop this.
            if args.seed is not None and args.prandom:
                torch.manual_seed(args.seed)
                torch.cuda.manual_seed(args.seed)
                torch.cuda.manual_seed_all(args.seed)
            '''
            # End of modification
            # compute output
            output = model(images)

            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        progress.display(len(val_loader))

        if writer is not None:
            progress.write_to_tensorboard(writer,
                                          prefix="test",
                                          global_step=epoch)

    return top1.avg, top5.avg
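
The progress.display and progress.write_to_tensorboard calls suggest a ProgressMeter close to the ImageNet-example version, extended with a TensorBoard hook. A rough sketch under that assumption follows; note that Examples #3 and #9 also pass an args/cfg object to the constructor, which this sketch omits.

class ProgressMeter:
    """Prints the current batch index together with each meter's value and average."""

    def __init__(self, num_batches, meters, prefix=""):
        self.num_batches = num_batches
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + "[{}/{}]".format(batch, self.num_batches)]
        entries += ["{} {:.3f} ({:.3f})".format(m.name, m.val, m.avg) for m in self.meters]
        print("\t".join(entries))

    def write_to_tensorboard(self, writer, prefix="train", global_step=0):
        # Assumed behavior: log each meter's average (and optionally its latest value).
        for m in self.meters:
            if getattr(m, "write_val", True):
                writer.add_scalar("{}/{}_val".format(prefix, m.name), m.val, global_step)
            writer.add_scalar("{}/{}_avg".format(prefix, m.name), m.avg, global_step)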
Example #3
def validate(val_loader, model, criterion, args, writer, epoch):
    batch_time = AverageMeter("Time", ":6.3f", write_val=False)
    losses = AverageMeter("Loss", ":.3f", write_val=False)
    top1 = AverageMeter("Acc@1", ":6.2f", write_val=False)
    top5 = AverageMeter("Acc@5", ":6.2f", write_val=False)
    progress = ProgressMeter(val_loader.num_batches,
                             [batch_time, losses, top1, top5],
                             args,
                             prefix="Test: ")

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()

        # confusion_matrix = torch.zeros(args.num_cls,args.num_cls)
        for i, data in enumerate(val_loader):
            # images, target = data[0]['data'], data[0]['label'].long().squeeze()
            images, target = data[0].cuda(), data[1].long().squeeze().cuda()

            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # print(target,torch.mean(images),acc1,acc5,loss,torch.mean(output))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))

            # _, preds = torch.max(output, 1)
            # for t, p in zip(target.view(-1), preds.view(-1)):
            #     confusion_matrix[t.long(), p.long()] += 1

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        progress.display(val_loader.num_batches)

        if writer is not None:
            progress.write_to_tensorboard(writer,
                                          prefix="test",
                                          global_step=epoch)

    # torch.save(confusion_matrix,'./conf_mat.pt')
    # print(top1.count)
    return top1.avg, top5.avg
Example #4
def validate(val_loader, model, criterion, args, writer, epoch):
    # batch_time = AverageMeter("Time", ":6.3f", write_val=False)
    losses = AverageMeter("Loss", ":.3f", write_val=False)
    top1 = AverageMeter("Acc@1", ":6.2f", write_val=False)
    top5 = AverageMeter("Acc@5", ":6.2f", write_val=False)
    #progress = ProgressMeter(
    #    len(val_loader), [batch_time, losses, top1, top5], prefix="Test: "
    #)
    progress = ProgressMeter(len(val_loader), [losses, top1, top5],
                             prefix="Test: ")
    # switch to evaluate mode
    model.eval()
    printModelScore(model, args)
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in tqdm.tqdm(enumerate(val_loader),
                                             ascii=True,
                                             total=len(val_loader)):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)

            target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)

            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))

            # measure elapsed time
            # batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        progress.display(len(val_loader))

        if writer is not None:
            progress.write_to_tensorboard(writer,
                                          prefix="test",
                                          global_step=epoch)

    return top1.avg, top5.avg, losses.avg
Example #5
def test(args, model, val_dataset, domain_num):
    val_dataloader = util_data.DataLoader(val_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=args.num_workers,
                                          drop_last=True,
                                          pin_memory=True)
    val_dataloader_iter = enumerate(val_dataloader)

    val_accs_each_c = []
    pred_ys = []
    y_vals = []
    x_val = None
    y_val = None

    model.eval()

    with torch.no_grad():
        for j, (x_val, y_val) in val_dataloader_iter:
            y_vals.append(y_val.cpu())
            x_val = x_val.cuda(args.gpu)
            y_val = y_val.cuda(args.gpu)

            pred_y = model(x_val,
                           domain_num * torch.ones_like(y_val),
                           with_ft=False)
            pred_ys.append(pred_y.cpu())
            # break

    pred_ys = torch.cat(pred_ys, 0)
    y_vals = torch.cat(y_vals, 0)
    val_acc = float(eval_utils.accuracy(pred_ys, y_vals, topk=(1, ))[0])
    val_acc_each_c = [(c_name,
                       float(
                           eval_utils.accuracy_of_c(pred_ys,
                                                    y_vals,
                                                    class_idx=c,
                                                    topk=(1, ))[0]))
                      for c, c_name in enumerate(val_dataset.classes)]
    model.train(True)
    return model, val_acc
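
eval_utils.accuracy and eval_utils.accuracy_of_c are defined elsewhere; judging by the `100 * c_val_acc` formatting in Example #8, they appear to return fractions in [0, 1] rather than percentages. A minimal sketch under that assumption (names and exact semantics are assumptions, not the repository's code):

import torch


def accuracy(pred, target, topk=(1,)):
    # Assumed eval_utils.accuracy: top-k accuracy as a fraction in [0, 1].
    maxk = max(topk)
    _, top_pred = pred.topk(maxk, dim=1, largest=True, sorted=True)
    pred_k = top_pred.t()                                   # shape: [maxk, N]
    correct = pred_k.eq(target.view(1, -1).expand_as(pred_k))
    return [correct[:k].reshape(-1).float().mean() for k in topk]


def accuracy_of_c(pred, target, class_idx, topk=(1,)):
    # Assumed eval_utils.accuracy_of_c: same metric restricted to one class.
    mask = target == class_idx
    if mask.sum() == 0:
        return [torch.tensor(0.0) for _ in topk]
    return accuracy(pred[mask], target[mask], topk=topk)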
Example #6
def test(args, dataset, model):
    dataloader = util_data.DataLoader(dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      drop_last=True,
                                      pin_memory=True)
    dataloader_iter = enumerate(dataloader)
    model.eval()
    pred_vals = []
    y_vals = []
    x_val = None
    y_val = None
    # print('------------------------dataload------------------------')
    with torch.no_grad():
        for j, (x_val, y_val) in dataloader_iter:
            y_vals.append(y_val.cpu())
            x_val = x_val.cuda(args.gpu)
            y_val = y_val.cuda(args.gpu)

            pred_val = model(x_val, 0 * torch.ones_like(y_val), with_ft=False)

            pred_vals.append(pred_val.cpu())

    pred_vals = torch.cat(pred_vals, 0)
    y_vals = torch.cat(y_vals, 0)
    total_test_accuracy = float(
        eval_utils.accuracy(pred_vals, y_vals, topk=(1, ))[0])
    # test2_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
    #                                                                  class_idx=c, topk=(1,))[0]))
    #                          for c, c_name in enumerate(dataset.classes)]

    print('Test accuracy for domain: ', dataset.domain)
    print('mean acc: ', total_test_accuracy)
    # for item in test2_accuracy_each_c:
    #     print(item[0], item[1])

    return
Example #7
def train(args, model, train_dataset, val_dataset, save_dir, domain_num):
    train_dataloader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                            num_workers=args.num_workers, drop_last=True, pin_memory=True)
    train_dataloader_iters = enumerate(train_dataloader)

    model.train(True)
    model = model.cuda(args.gpu)

    params = get_optimizer_params(model, args.learning_rate, weight_decay=args.weight_decay,
                                  double_bias_lr=True, base_weight_factor=0.1)

    optimizer = optim.Adam(params, betas=(0.9, 0.999))
    ce_loss = nn.CrossEntropyLoss()

    writer = SummaryWriter(log_dir=join(save_dir, 'logs'))
    print('domain_num: ', domain_num)
    global best_accuracy
    global best_accuracies_each_c
    global best_mean_val_accuracies
    global best_total_val_accuracies

    best_accuracy = 0.0
    best_accuracies_each_c = []
    best_mean_val_accuracies = []
    best_total_val_accuracies = []

    for i in range(args.iter):
        try:
            _, (x_s, y_s) = train_dataloader_iters.__next__()
        except StopIteration:
            train_dataloader_iters = enumerate(train_dataloader)
            _, (x_s, y_s) = train_dataloader_iters.__next__()

        optimizer.zero_grad()

        x_s, y_s = x_s.cuda(args.gpu), y_s.cuda(args.gpu)
        domain_idx = torch.ones(x_s.shape[0], dtype=torch.long).cuda(args.gpu)
        pred, f = model(x_s, domain_num * domain_idx, with_ft=True)
        loss = ce_loss(pred, y_s)
        writer.add_scalar("Train Loss", loss, i)
        loss.backward()
        optimizer.step()

        if (i % 500 == 0 and i != 0):
            # print('------%d val start' % (i))
            model.eval()
            total_val_accuracies = []
            mean_val_accuracies = []
            val_accuracies_each_c = []

            val_dataloader = util_data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True,
                                                  num_workers=args.num_workers, drop_last=True, pin_memory=True)
            val_dataloader_iter = enumerate(val_dataloader)

            pred_vals = []
            y_vals = []
            x_val = None
            y_val = None
            # print('------------------------dataload------------------------')
            with torch.no_grad():
                for j, (x_val, y_val) in val_dataloader_iter:
                    y_vals.append(y_val.cpu())
                    x_val = x_val.cuda(args.gpu)
                    y_val = y_val.cuda(args.gpu)

                    pred_val = model(x_val, domain_num * torch.ones_like(y_val), with_ft=False)

                    pred_vals.append(pred_val.cpu())

            pred_vals = torch.cat(pred_vals, 0)
            y_vals = torch.cat(y_vals, 0)
            total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
            val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                           class_idx=c, topk=(1,))[0]))
                                   for c, c_name in enumerate(val_dataset.classes)]

            mean_val_accuracy = float(
                torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))
            total_val_accuracies.append(total_val_accuracy)
            val_accuracies_each_c.append(val_accuracy_each_c)
            mean_val_accuracies.append(mean_val_accuracy)

            val_accuracy = float(torch.mean(torch.FloatTensor(total_val_accuracies)))
            print('%d th iteration accuracy: %f ' % (i, val_accuracy))
            del x_val, y_val, pred_val, pred_vals, y_vals
            del val_dataloader_iter

            # train mode
            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                best_accuracies_each_c = val_accuracies_each_c
                best_mean_val_accuracies = mean_val_accuracies
                best_total_val_accuracies = total_val_accuracies
                # print('%d iter val acc %.3f' % (i, val_accuracy))
                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}

                # save best checkpoint
                io_utils.save_check(save_dir, i, model_dict, optimizer_dict, best=True)
            model.train(True)
            model = model.cuda(args.gpu)

        if (i % 10000 == 0 and i != 0):
            print('%d iter complete' % (i))
            model_dict = {'model': model.cpu().state_dict()}
            optimizer_dict = {'optimizer': optimizer.state_dict()}

            # save best checkpoint
            io_utils.save_check(save_dir, i, model_dict, optimizer_dict, best=False)

    writer.flush()
    writer.close()

    return
Example #8
def main():
    args = parse_args()
    args.dsbn = True if 'dsbn' in args.model_name else False  # set dsbn
    args.source_dataset = '|'.join(args.source_datasets)
    args.target_dataset = '|'.join(args.target_datasets)
    torch.cuda.set_device(args.gpu)  # set current gpu device id so pin_memory works on the target gpu
    start_time = datetime.datetime.now()  # execution start time

    # make save_dir
    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    # create log file
    log_filename = 'train_records.log'
    log_path = os.path.join(args.save_dir, log_filename)
    logger = io_utils.get_logger(__name__, log_file=log_path, write_level=logging.INFO,
                                 print_level=logging.INFO if args.print_console else None,
                                 mode='a' if args.resume else 'w')

    # set num_classes by checking exp_setting
    if args.num_classes == 0:
        if args.exp_setting == 'digits':
            logger.warning('num_classes not specified; setting to 10 for digits.')
            args.num_classes = 10
        elif args.exp_setting == 'office':
            logger.warning('num_classes not specified; setting to 31 for office.')
            args.num_classes = 31
        elif args.exp_setting in ['visda', 'imageclef']:
            logger.warning('num_classes not specified; setting to 12.')
            args.num_classes = 12
        elif args.exp_setting in ['office-home']:
            logger.warning('num_classes not specified; setting to 65 for office-home.')
            args.num_classes = 65
        elif args.exp_setting in ['office-caltech']:
            args.num_classes = 10
        else:
            raise AttributeError('Unknown exp_setting: {}'.format(args.exp_setting))

    if args.weight_irm > 0.0:
        args.weight_source_irm = args.weight_irm
        args.weight_target_irm = args.weight_irm

    if(args.iters_active_irm > 0):
        weight_irm_backup = [args.weight_source_irm, args.weight_target_irm]
        args.weight_source_irm = 0
        args.weight_target_irm = 0

    if args.manual_seed:
        # set manual seed
        args.manual_seed = np.uint32(args.manual_seed)
        torch.manual_seed(args.manual_seed)
        torch.cuda.manual_seed(args.manual_seed)
        random.seed(args.manual_seed)
        np.random.seed(args.manual_seed)
        logger.info('Random Seed: {}'.format(int(args.manual_seed)))
        args.random_seed = args.manual_seed  # save seed into args
    else:
        seed = np.uint32(random.randrange(sys.maxsize))
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        random.seed(seed)
        np.random.seed(np.uint32(seed))
        logger.info('Random Seed: {}'.format(seed))
        args.random_seed = seed  # save seed into args

    if args.resume:
        logger.info('Resume training')
    else:
        logger.info('\nArguments:\n' + pprint.pformat(vars(args), indent=4))  # print args
    torch.save(vars(args), os.path.join(args.save_dir, 'args_dict.pth'))  # save args

    num_classes = args.num_classes
    in_features = args.in_features if args.in_features != 0 else num_classes
    num_domains = len(args.source_datasets) + len(args.target_datasets)
    if args.merge_sources:
        num_source_domains = 1
    else:
        num_source_domains = len(args.source_datasets)
    num_target_domains = len(args.target_datasets)

    # tfboard
    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        tfboard_dir = os.path.join(args.save_dir, 'tfboard')
        if not os.path.isdir(tfboard_dir):
            os.makedirs(tfboard_dir)
        writer = SummaryWriter(tfboard_dir)

    # resume
    if args.resume:
        try:
            checkpoints = io_utils.load_latest_checkpoints(args.save_dir, args, logger)
        except FileNotFoundError:
            logger.warning('Latest checkpoints are not found! Trying to load best model...')
            checkpoints = io_utils.load_best_checkpoints(args.save_dir, args, logger)

        start_iter = checkpoints[0]['iteration'] + 1
    else:
        start_iter = 1

    ###################################################################################################################
    #                                               Data Loading                                                      #
    ###################################################################################################################

    source_train_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, source_name, 'train', args.jitter))
                             for source_name in args.source_datasets]
    target_train_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, target_name, 'train', args.jitter))
                             for target_name in args.target_datasets]

    if args.merge_sources:
        for i in range(len(source_train_datasets)):
            if i == 0:
                merged_source_train_datasets = source_train_datasets[i]
            else:
                # concatenate dataset
                merged_source_train_datasets = merged_source_train_datasets + source_train_datasets[i]
        source_train_datasets = [merged_source_train_datasets]

    # dataloader
    source_train_dataloaders = [data.DataLoader(source_train_dataset, batch_size=args.batch_size, shuffle=True,
                                                num_workers=args.num_workers, drop_last=True, pin_memory=True)
                                for source_train_dataset in source_train_datasets]
    target_train_dataloaders = [data.DataLoader(target_train_dataset, batch_size=args.batch_size, shuffle=True,
                                                num_workers=args.num_workers, drop_last=True, pin_memory=True)
                                for target_train_dataset in target_train_datasets]

    source_train_dataloader_iters = [enumerate(source_train_dataloader) for source_train_dataloader in
                                     source_train_dataloaders]
    target_train_dataloader_iters = [enumerate(target_train_dataloader) for target_train_dataloader in
                                     target_train_dataloaders]

    # validation dataloader
    target_val_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, target_name, 'val', args.jitter))
                           for target_name in args.target_datasets]
    target_val_dataloaders = [data.DataLoader(target_val_dataset, batch_size=args.batch_size,
                                              shuffle=False, num_workers=args.num_workers, pin_memory=True)
                              for target_val_dataset in target_val_datasets]

    ###################################################################################################################
    #                                               Model Loading                                                     #
    ###################################################################################################################
    model = get_model(args.model_name, args.num_classes, args.in_features, num_domains=num_domains, pretrained=True)

    model.train(True)
    if args.resume:
        model.load_state_dict(checkpoints[0]['model'])
    elif args.init_model_path:
        init_checkpoint = torch.load(args.init_model_path)
        model.load_state_dict(init_checkpoint['model'])
    model = model.cuda(args.gpu)

    params = get_optimizer_params(model, args.learning_rate, weight_decay=args.weight_decay,
                                  double_bias_lr=args.double_bias_lr, base_weight_factor=args.base_weight_factor)

    if args.adv_loss:
        discriminators = [get_discriminator(args.exp_setting,
                                            in_features=args.in_features if args.in_features != 0 else args.num_classes)
                          for _ in range(num_target_domains) for _ in range(num_source_domains)]
        discriminators = [discriminator.cuda(args.gpu) for discriminator in discriminators]
        D_params = get_optimizer_params(discriminators, args.learning_rate, weight_decay=args.weight_decay,
                                        double_bias_lr=args.double_bias_lr, base_weight_factor=None)
        if args.resume:
            if checkpoints[1]:
                for d_idx, discriminator in enumerate(discriminators):
                    discriminator.load_state_dict(checkpoints[1]['discriminators'][d_idx])

    if args.sm_loss:
        srcs_centroids = [Centroids(in_features, num_classes) for _ in range(num_source_domains)]
        trgs_centroids = [Centroids(in_features, num_classes) for _ in range(num_target_domains)]

        if args.resume:
            if checkpoints[2]:
                for src_idx, src_centroids in enumerate(srcs_centroids):
                    src_centroids.load_state_dict(checkpoints[2]['srcs_centroids'][src_idx])
                for trg_idx, trg_centroids in enumerate(trgs_centroids):
                    trg_centroids.load_state_dict(checkpoints[2]['trgs_centroids'][trg_idx])

        srcs_centroids = [src_centroids.cuda(args.gpu) for src_centroids in srcs_centroids]
        trgs_centroids = [trg_centroids.cuda(args.gpu) for trg_centroids in trgs_centroids]

    ###################################################################################################################
    #                                               Train Configurations                                              #
    ###################################################################################################################
    ce_loss = nn.CrossEntropyLoss()
    bce_loss = nn.BCEWithLogitsLoss()
    # mse_loss = nn.MSELoss()

    lr_scheduler = LRScheduler(args.learning_rate, args.warmup_learning_rate, args.warmup_step,
                               num_steps=args.max_step,
                               alpha=10, beta=0.75, double_bias_lr=args.double_bias_lr,
                               base_weight_factor=args.base_weight_factor)

    if args.optimizer.lower() == 'sgd':
        optimizer = optim.SGD(params, momentum=0.9, nesterov=True)
    else:
        optimizer = optim.Adam(params, betas=(args.beta1, args.beta2))

    if args.resume:
        if checkpoints[1]:
            optimizer.load_state_dict(checkpoints[1]['optimizer'])

    if args.adv_loss:
        if args.optimizer.lower() == 'sgd':
            optimizer_D = optim.SGD(D_params, momentum=0.9, nesterov=True)
        else:
            optimizer_D = optim.Adam(D_params, betas=(args.beta1, args.beta2))

        if args.resume:
            if checkpoints[1]:
                optimizer_D.load_state_dict(checkpoints[1]['optimizer_D'])

    # Train Starts
    logger.info('Train Starts')
    domain_loss_adjust_factor = args.domain_loss_adjust_factor

    monitor = Monitor()

    global best_accuracy
    global best_accuracies_each_c
    global best_mean_val_accuracies
    global best_total_val_accuracies
    best_accuracy = 0.0
    best_accuracies_each_c = []
    best_mean_val_accuracies = []
    best_total_val_accuracies = []

    for i_iter in range(start_iter, args.early_stop_step + 1):

        if(args.iters_active_irm > 0):
            if(i_iter > args.iters_active_irm):
                args.weight_source_irm = weight_irm_backup[0]
                args.weight_target_irm = weight_irm_backup[1]

        src_inputs = []
        for src_dataloader_idx in range(len(source_train_dataloader_iters)):
            try:
                _, (x_s, y_s) = source_train_dataloader_iters[src_dataloader_idx].__next__()
                src_inputs.append((x_s, y_s))
            except StopIteration:
                source_train_dataloader_iters[src_dataloader_idx] = enumerate(
                    source_train_dataloaders[src_dataloader_idx])
                _, (x_s, y_s) = source_train_dataloader_iters[src_dataloader_idx].__next__()
                src_inputs.append((x_s, y_s))

        trg_inputs = []
        for trg_dataloader_idx in range(len(target_train_dataloader_iters)):
            try:
                _, (x_t, _) = target_train_dataloader_iters[trg_dataloader_idx].__next__()
                trg_inputs.append((x_t, None))
            except StopIteration:
                target_train_dataloader_iters[trg_dataloader_idx] = enumerate(
                    target_train_dataloaders[trg_dataloader_idx])
                _, (x_t, _) = target_train_dataloader_iters[trg_dataloader_idx].__next__()
                trg_inputs.append((x_t, None))

        current_lr = lr_scheduler.current_lr(i_iter)
        adaptation_lambda = adaptation_factor((i_iter - args.warmup_step) / float(args.max_step),
                                              gamma=args.adaptation_gamma)
        # init optimizer
        optimizer.zero_grad()
        lr_scheduler(optimizer, i_iter)
        if args.adv_loss:
            optimizer_D.zero_grad()
            lr_scheduler(optimizer_D, i_iter)

        ########################################################################################################
        #                                               Train G                                                #
        ########################################################################################################
        if args.adv_loss:
            for discriminator in discriminators:
                for param in discriminator.parameters():
                    param.requires_grad = False
        # ship to cuda
        src_inputs = [(x_s.cuda(args.gpu), y_s.cuda(args.gpu)) for (x_s, y_s) in src_inputs]
        trg_inputs = [(x_t.cuda(args.gpu), None) for (x_t, _) in trg_inputs]

        if args.dsbn:
            src_preds = []
            for src_idx, (x_s, y_s) in enumerate(src_inputs):
                pred_s, f_s = model(x_s, src_idx * torch.ones(x_s.shape[0], dtype=torch.long).cuda(args.gpu),
                                    with_ft=True)
                src_preds.append((pred_s, f_s))

            trg_preds = []
            for trg_idx, (x_t, _) in enumerate(trg_inputs, num_source_domains):
                pred_t, f_t = model(x_t, trg_idx * torch.ones(x_t.shape[0], dtype=torch.long).cuda(args.gpu),
                                    with_ft=True)
                trg_preds.append((pred_t, f_t))
        else:
            src_preds = []
            for src_idx, (x_s, y_s) in enumerate(src_inputs):
                pred_s, f_s = model(x_s, with_ft=True)
                src_preds.append((pred_s, f_s))

            trg_preds = []
            for trg_idx, (x_t, _) in enumerate(trg_inputs, num_source_domains):
                pred_t, f_t = model(x_t, with_ft=True)
                trg_preds.append((pred_t, f_t))

        Closs_src = 0
        Closs_src_irm = 0
        for (_, y_s), (pred_s, f_s) in zip(src_inputs, src_preds):
            Closs_src = Closs_src + ce_loss(pred_s, y_s) / float(num_source_domains)

            if(args.weight_source_irm > 0):
                Closs_src_irm += feature_penalty(f_s, model.fc, ce_loss, y_s)
                

        monitor.update({"Loss/Closs_src": float(Closs_src)})
        Floss = Closs_src

        if(args.weight_source_irm > 0):
            Floss += Closs_src_irm * args.weight_source_irm
            monitor.update({"Loss/Closs_src_irm": float(Closs_src_irm)})

        if args.adv_loss:
            # adversarial loss
            Gloss = 0
            for trg_idx, (_, f_t) in enumerate(trg_preds):
                for src_idx, (_, f_s) in enumerate(src_preds):
                    Dout_s = discriminators[trg_idx * num_source_domains + src_idx](f_s)
                    source_label = torch.zeros_like(Dout_s).cuda(args.gpu)
                    loss_adv_src = domain_loss_adjust_factor * bce_loss(Dout_s, source_label) / 2

                    Dout_t = discriminators[trg_idx * num_source_domains + src_idx](f_t)
                    target_label = torch.ones_like(Dout_t).cuda(args.gpu)
                    loss_adv_trg = domain_loss_adjust_factor * bce_loss(Dout_t, target_label) / 2
                    Gloss = Gloss - (loss_adv_src + loss_adv_trg)
            Gloss = Gloss / float(num_target_domains * num_source_domains)
            monitor.update({'Loss/Gloss': float(Gloss)})

            Floss = Floss + adaptation_lambda * Gloss

        # pseudo label generation
        pred_t_pseudos = []
        if args.dsbn:
            with torch.no_grad():
                model.eval()
                for trg_idx, (x_t, _) in enumerate(trg_inputs, num_source_domains):
                    pred_t_pseudo = model(x_t, trg_idx * torch.ones(x_t.shape[0], dtype=torch.long).cuda(args.gpu),
                                        with_ft=False)
                    pred_t_pseudos.append(pred_t_pseudo)
                model.train(True)
        else:
            with torch.no_grad():
                model.eval()
                for trg_idx, (x_t, _) in enumerate(trg_inputs, num_source_domains):
                    pred_t_pseudo = model(x_t, with_ft=False)
                    pred_t_pseudos.append(pred_t_pseudo)
                model.train(True)

        if(args.weight_target_irm > 0):
            Closs_trg_irm = 0
            #Closs_trg = 0
            for pred_t_pseudo, (pred_t, f_t) in zip(pred_t_pseudos, trg_preds):
                y_t_pseudo = torch.argmax(pred_t_pseudo, 1).detach()
                #Closs_trg = Closs_trg + ce_loss(pred_t, y_t_pseudo)
                Closs_trg_irm += feature_penalty(f_t, model.fc, ce_loss, y_t_pseudo)

            #Floss += Closs_trg
            Floss += Closs_trg_irm * args.weight_target_irm
            
            monitor.update({"Loss/Closs_trg_irm": float(Closs_trg_irm)})
            #monitor.update({"Loss/Closs_trg": float(Closs_trg_irm)})

        # moving semantic loss
        if args.sm_loss:
            current_srcs_centroids = [src_centroids(f_s, y_s) for src_centroids, (x_s, y_s), (_, f_s) in
                                      zip(srcs_centroids, src_inputs, src_preds)]

            current_trgs_centroids = [trg_centroids(f_t, torch.argmax(pred_t_pseudo, 1).detach()) for
                                      trg_centroids, pred_t_pseudo, (_, f_t) in
                                      zip(trgs_centroids, pred_t_pseudos, trg_preds)]

            semantic_loss = 0
            for current_trg_centroids in current_trgs_centroids:
                for current_src_centroids in current_srcs_centroids:
                    semantic_loss = semantic_loss + args.sm_etha * semantic_loss_calc(current_src_centroids,
                                                                                      current_trg_centroids)
            semantic_loss = semantic_loss / float(num_target_domains * num_source_domains)
            monitor.update({'Loss/SMloss': float(semantic_loss)})

            Floss = Floss + adaptation_lambda * semantic_loss

   
        # Floss backward
        Floss.backward()
        optimizer.step()
        ########################################################################################################
        #                                               Train D                                                #
        ########################################################################################################
        if args.adv_loss:
            for discriminator in discriminators:
                for param in discriminator.parameters():
                    param.requires_grad = True

        if args.adv_loss:
            # adversarial loss
            Dloss = 0
            for trg_idx, (_, f_t) in enumerate(trg_preds):
                for src_idx, (_, f_s) in enumerate(src_preds):
                    Dout_s = discriminators[trg_idx * num_source_domains + src_idx](f_s.detach())
                    source_label = torch.zeros_like(Dout_s).cuda(args.gpu)
                    loss_adv_src = domain_loss_adjust_factor * bce_loss(Dout_s, source_label) / 2

                    # target
                    Dout_t = discriminators[trg_idx * num_source_domains + src_idx](f_t.detach())
                    target_label = torch.ones_like(Dout_t).cuda(args.gpu)
                    loss_adv_trg = domain_loss_adjust_factor * bce_loss(Dout_t, target_label) / 2
                    Dloss = Dloss + loss_adv_src + loss_adv_trg
            Dloss = Dloss / float(num_target_domains * num_source_domains)
            monitor.update({'Loss/Dloss': float(Dloss)})
            Dloss = adaptation_lambda * Dloss
            Dloss.backward()
            optimizer_D.step()

        if args.sm_loss:
            for src_centroids, current_src_centroids in zip(srcs_centroids, current_srcs_centroids):
                src_centroids.centroids.data = current_src_centroids.data
            for trg_centroids, current_trg_centroids in zip(trgs_centroids, current_trgs_centroids):
                trg_centroids.centroids.data = current_trg_centroids.data

        if i_iter % args.disp_interval == 0 and i_iter != 0:
            disp_msg = 'iter[{:8d}/{:8d}], '.format(i_iter, args.early_stop_step)
            disp_msg += str(monitor)
            if args.adv_loss or args.sm_loss:
                disp_msg += ', lambda={:.6f}'.format(adaptation_lambda)
            disp_msg += ', lr={:.6f}'.format(current_lr)
            logger.info(disp_msg)

            if args.use_tfboard:
                if args.save_model_hist:
                    for name, param in model.named_parameters():
                        writer.add_histogram(name, param.cpu().data.numpy(), i_iter, bins='auto')

                for k, v in monitor.losses.items():
                    writer.add_scalar(k, v, i_iter)
                if args.adv_loss or args.sm_loss:
                    writer.add_scalar('adaptation_lambda', adaptation_lambda, i_iter)
                writer.add_scalar('learning rate', current_lr, i_iter)
            monitor.reset()

        if i_iter % args.save_interval == 0 and i_iter != 0:
            logger.info("Elapsed Time: {}".format(datetime.datetime.now() - start_time))
            logger.info("Start Evaluation at {:d}".format(i_iter))

            target_val_dataloader_iters = [enumerate(target_val_dataloader)
                                           for target_val_dataloader in target_val_dataloaders]

            total_val_accuracies = []
            mean_val_accuracies = []
            val_accuracies_each_c = []
            model.eval()  # evaluation mode
            for trg_idx, target_val_dataloader_iter in enumerate(target_val_dataloader_iters, num_source_domains):
                pred_vals = []
                y_vals = []
                x_val = None
                y_val = None
                pred_val = None
                with torch.no_grad():
                    for i, (x_val, y_val) in target_val_dataloader_iter:
                        y_vals.append(y_val.cpu())
                        x_val = x_val.cuda(args.gpu)
                        y_val = y_val.cuda(args.gpu)

                        if args.dsbn:
                            pred_val = model(x_val, trg_idx * torch.ones_like(y_val), with_ft=False)
                        else:
                            pred_val = model(x_val, with_ft=False)

                        pred_vals.append(pred_val.cpu())

                pred_vals = torch.cat(pred_vals, 0)
                y_vals = torch.cat(y_vals, 0)
                total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])

                val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                               class_idx=c, topk=(1,))[0]))
                                       for c, c_name in
                                       enumerate(target_val_datasets[trg_idx - num_source_domains].classes)]
                logger.info('\n{} Accuracy of Each class\n'.format(args.target_datasets[trg_idx - num_source_domains]) +
                            ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                                     for c_name, c_val_acc in val_accuracy_each_c]))
                mean_val_accuracy = float(
                    torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))

                logger.info('{} mean Accuracy: {:.2f}%'.format(
                    args.target_datasets[trg_idx - num_source_domains], 100 * mean_val_accuracy))
                logger.info(
                    '{} Accuracy: {:.2f}%'.format(args.target_datasets[trg_idx - num_source_domains],
                                                  total_val_accuracy * 100))

                total_val_accuracies.append(total_val_accuracy)
                val_accuracies_each_c.append(val_accuracy_each_c)
                mean_val_accuracies.append(mean_val_accuracy)

                if args.use_tfboard:
                    writer.add_scalar('Val_acc', total_val_accuracy, i_iter)
                    for c_name, c_val_acc in val_accuracy_each_c:
                        writer.add_scalar('Val_acc_of_{}'.format(c_name), c_val_acc, i_iter)
            model.train(True)  # train mode

            if args.exp_setting.lower() == 'visda':
                val_accuracy = float(torch.mean(torch.FloatTensor(mean_val_accuracies)))
            else:
                val_accuracy = float(torch.mean(torch.FloatTensor(total_val_accuracies)))

            # for memory
            del x_val, y_val, pred_val, pred_vals, y_vals
            for target_val_dataloader_iter in target_val_dataloader_iters:
                del target_val_dataloader_iter
            del target_val_dataloader_iters

            if val_accuracy > best_accuracy:
                # save best model
                best_accuracy = val_accuracy
                best_accuracies_each_c = val_accuracies_each_c
                best_mean_val_accuracies = mean_val_accuracies
                best_total_val_accuracies = total_val_accuracies
                options = io_utils.get_model_options_from_args(args, i_iter)
                # dict to save
                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}
                if args.adv_loss:
                    optimizer_dict.update({'optimizer_D': optimizer_D.state_dict(),
                                           'discriminators': [discriminator.cpu().state_dict()
                                                              for discriminator in discriminators],
                                           'source_datasets': args.source_datasets,
                                           'target_datasets': args.target_datasets})
                centroids_dict = {}
                if args.sm_loss:
                    centroids_dict = {
                        'srcs_centroids': [src_centroids.cpu().state_dict() for src_centroids in srcs_centroids],
                        'trgs_centroids': [trg_centroids.cpu().state_dict() for trg_centroids in trgs_centroids]}
                # save best checkpoint
                io_utils.save_checkpoints(args.save_dir, options, i_iter, model_dict, optimizer_dict, centroids_dict,
                                          logger, best=True)
                # ship to cuda
                model = model.cuda(args.gpu)
                if args.adv_loss:
                    discriminators = [discriminator.cuda(args.gpu) for discriminator in discriminators]
                if args.sm_loss:
                    srcs_centroids = [src_centroids.cuda(args.gpu) for src_centroids in srcs_centroids]
                    trgs_centroids = [trg_centroids.cuda(args.gpu) for trg_centroids in trgs_centroids]

                # save best result into textfile
                contents = [' '.join(sys.argv) + '\n',
                            "best accuracy: {:.2f}%\n".format(best_accuracy)]
                for d_idx in range(num_target_domains):
                    best_accuracy_each_c = best_accuracies_each_c[d_idx]
                    best_mean_val_accuracy = best_mean_val_accuracies[d_idx]
                    best_total_val_accuracy = best_total_val_accuracies[d_idx]
                    contents.extend(["{}2{}\n".format(args.source_dataset, args.target_datasets[d_idx]),
                                     "best total acc: {:.2f}%\n".format(100 * best_total_val_accuracy),
                                     "best mean acc: {:.2f}%\n".format(100 * best_mean_val_accuracy),
                                     'Best Accs: ' + ''.join(["{:.2f}% ".format(100 * c_val_acc)
                                                              for _, c_val_acc in best_accuracy_each_c]) + '\n'])

                best_result_path = os.path.join('./output', '{}_best_result.txt'.format(
                    os.path.splitext(os.path.basename(__file__))[0]))
                with open(best_result_path, 'a+') as f:
                    f.writelines(contents)

            # logging best model results
            for trg_idx in range(num_target_domains):
                best_accuracy_each_c = best_accuracies_each_c[trg_idx]
                best_total_val_accuracy = best_total_val_accuracies[trg_idx]
                best_mean_val_accuracy = best_mean_val_accuracies[trg_idx]
                logger.info(
                    '\nBest {} Accuracy of Each class\n'.format(args.target_datasets[trg_idx]) +
                    ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                             for c_name, c_val_acc in best_accuracy_each_c]))
                logger.info('Best Accs: ' + ''.join(["{:.2f}% ".format(100 * c_val_acc)
                                                     for _, c_val_acc in best_accuracy_each_c]))
                logger.info('Best {} mean Accuracy: {:.2f}%'.format(args.target_datasets[trg_idx],
                                                                    100 * best_mean_val_accuracy))
                logger.info('Best {} Accuracy: {:.2f}%'.format(args.target_datasets[trg_idx],
                                                               100 * best_total_val_accuracy))
            logger.info("Best model's Average Accuracy of targets: {:.2f}".format(100 * best_accuracy))

            if args.save_ckpts:
                # get options
                options = io_utils.get_model_options_from_args(args, i_iter)
                # dict to save
                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}
                if args.adv_loss:
                    optimizer_dict.update({'optimizer_D': optimizer_D.state_dict(),
                                           'discriminators': [discriminator.cpu().state_dict()
                                                              for discriminator in discriminators]})
                centroids_dict = {}
                if args.sm_loss:
                    centroids_dict = {
                        'srcs_centroids': [src_centroids.cpu().state_dict() for src_centroids in srcs_centroids],
                        'trgs_centroids': [trg_centroids.cpu().state_dict() for trg_centroids in trgs_centroids]}
                # save checkpoint
                io_utils.save_checkpoints(args.save_dir, options, i_iter, model_dict, optimizer_dict, centroids_dict,
                                          logger, best=False)

                # ship to cuda
                model = model.cuda(args.gpu)
                if args.adv_loss:
                    discriminators = [discriminator.cuda(args.gpu) for discriminator in discriminators]
                if args.sm_loss:
                    srcs_centroids = [src_centroids.cuda(args.gpu) for src_centroids in srcs_centroids]
                    trgs_centroids = [trg_centroids.cuda(args.gpu) for trg_centroids in trgs_centroids]

    if args.use_tfboard:
        writer.close()

    logger.info('Total Time: {}'.format((datetime.datetime.now() - start_time)))
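
feature_penalty is referenced but not defined in Example #8. Given the weight_source_irm/weight_target_irm options, it is presumably an invariance penalty in the IRM family; a rough sketch of the IRMv1 penalty from Arjovsky et al. (2019), adapted to the call feature_penalty(f, model.fc, ce_loss, y), might look like this (an assumption, not the repository's actual implementation):

import torch


def feature_penalty(features, classifier, loss_fn, labels):
    # Assumed IRMv1-style penalty: squared norm of the gradient of the risk with
    # respect to a dummy scalar multiplier on the logits.
    scale = torch.ones(1, device=features.device, requires_grad=True)
    logits = classifier(features) * scale
    loss = loss_fn(logits, labels)
    grad = torch.autograd.grad(loss, [scale], create_graph=True)[0]
    return (grad ** 2).sum()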
Example #9
def train(train_loader, model, criterion, optimizer, epoch, cfg, writer):
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.3f")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        train_loader.num_batches,
        [batch_time, data_time, losses, top1, top5],
        cfg,
        prefix=f"Epoch: [{epoch}]",
    )

    # switch to train mode
    model.train()

    batch_size = train_loader.batch_size
    num_batches = train_loader.num_batches
    end = time.time()

    for i, data in enumerate(train_loader):
        # images, target = data[0]['data'],data[0]['label'].long().squeeze()
        images, target = data[0].cuda(), data[1].long().squeeze().cuda()
        # measure data loading time
        data_time.update(time.time() - end)

        if cfg.cs_kd:

            batch_size = images.size(0)
            loss_batch_size = batch_size // 2
            targets_ = target[:batch_size // 2]
            outputs = model(images[:batch_size // 2])
            loss = torch.mean(criterion(outputs, targets_))
            # loss += loss.item()

            with torch.no_grad():
                outputs_cls = model(images[batch_size // 2:])
            cls_loss = kdloss(outputs[:batch_size // 2], outputs_cls.detach())
            lamda = 3
            loss += lamda * cls_loss
            acc1, acc5 = accuracy(outputs, targets_, topk=(1, 5))
        else:
            batch_size = images.size(0)
            loss_batch_size = batch_size
            #compute output
            output = model(images)
            loss = criterion(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

        # print(i, batch_size, loss)

        # measure accuracy and record loss

        losses.update(loss.item(), loss_batch_size)
        top1.update(acc1.item(), loss_batch_size)
        top5.update(acc5.item(), loss_batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % cfg.print_freq == 0 or i == num_batches - 1:
            t = (num_batches * epoch + i) * batch_size
            progress.display(i)
            progress.write_to_tensorboard(writer,
                                          prefix="train",
                                          global_step=t)

    # train_loader.reset()
    # print(top1.count)
    return top1.avg, top5.avg
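
kdloss in the cs_kd branch of Example #9 is not shown. Class-wise self-knowledge distillation typically uses a temperature-scaled KL divergence between the two half-batches; a sketch under that assumption (the temperature value here is arbitrary):

import torch.nn.functional as F


def kdloss(student_logits, teacher_logits, temperature=4.0):
    # Assumed implementation: softened KL divergence between student and teacher
    # predictions, scaled by T^2 as in standard knowledge distillation.
    log_p = F.log_softmax(student_logits / temperature, dim=1)
    q = F.softmax(teacher_logits / temperature, dim=1)
    return F.kl_div(log_p, q, reduction="batchmean") * (temperature ** 2)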
Example #10
def evaluate(model,
             trainer,
             data_loader,
             epoch=0,
             batch_size=opt.batch_size,
             logger=None,
             tb_logger=None,
             max_iters=None):
    """ Evaluate model

    Similar in structure to `train()`, with the same bookkeeping features and
    wrapper items. The differences are that evaluation stops after `max_iters`
    iterations when it is specified, and that an `EvalMetrics` instance is
    initialized.

    The latter is currently used to save predictions and ground truths to
    compute the confusion matrix.

    Args:
        model: Classification model
        trainer (Trainer): Training wrapper
        data_loader (torch.data.Dataloader): Generator data loading instance
        epoch (int): Current epoch
        logger (Logger): Logger. Used to display/log metrics
        tb_logger (SummaryWriter): Tensorboard Logger
        batch_size (int): Batch size
        max_iters (int): Max iterations

    Returns:
        float: Loss average
        float: Accuracy average
        float: Run time average
        EvalMetrics: Evaluation wrapper to compute CMs

    """
    criterion = trainer.criterion

    # Initialize meter and metrics
    meter = get_meter(meters=['batch_time', 'loss', 'acc'])
    predictions, gtruth, ids = [], [], []
    classes = data_loader.dataset.classes
    metrics = EvalMetrics(classes, predictions, gtruth, ids, trainer.model_dir)

    # Switch to evaluate mode
    model.eval()
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            # process batch items: images, labels
            img = to_cuda(batch[CONST.IMG], trainer.computing_device)
            target = to_cuda(batch[CONST.LBL],
                             trainer.computing_device,
                             label=True)
            id = batch[CONST.ID]

            # compute output
            end = time.time()
            logits = model(img)
            loss = criterion(logits, target)
            acc = accuracy(logits, target)
            batch_size = list(batch[CONST.LBL].shape)[0]

            # update metrics
            meter['acc'].update(acc, batch_size)
            meter['loss'].update(loss, batch_size)

            # update metrics2
            metrics.update(logits, target, id)

            # measure elapsed time
            meter['batch_time'].update(time.time() - end, batch_size)

            if i % opt.print_freq == 0:
                log = 'EVAL [{:02d}][{:2d}/{:2d}] TIME {:10} ACC {:10} LOSS {' \
                      ':10}'.format(epoch, i, len(data_loader),
                    "{t.val:.3f} ({t.avg:.3f})".format(t=meter['batch_time']),
                    "{t.val:.3f} ({t.avg:.3f})".format(t=meter['acc']),
                    "{t.val:.3f} ({t.avg:.3f})".format(t=meter['loss'])
                                    )
                logger.info(log)

                if tb_logger is not None:
                    tb_logger.add_scalar('test/loss', meter['loss'].val, epoch)
                    tb_logger.add_scalar('test/accuracy', meter['acc'].val,
                                         epoch)

            if max_iters is not None and i >= max_iters:
                break

        # Print last eval
        log = 'EVAL [{:02d}][{:2d}/{:2d}] TIME {:10} ACC {:10} LOSS {' \
              ':10}'.format(epoch, i, len(data_loader),
                            "{t.val:.3f} ({t.avg:.3f})".format(t=meter['batch_time']),
                            "{t.val:.3f} ({t.avg:.3f})".format(t=meter['acc']),
                            "{t.val:.3f} ({t.avg:.3f})".format(t=meter['loss'])
                            )
        logger.info(log)

        if tb_logger is not None:
            tb_logger.add_scalar('test-epoch/loss', meter['loss'].avg, epoch)
            tb_logger.add_scalar('test-epoch/accuracy', meter['acc'].avg,
                                 epoch)

    return meter['loss'].avg, meter['acc'].avg, meter['batch_time'], metrics
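
get_meter and to_cuda in Example #10 are project helpers. Plausible minimal versions, reusing the AverageMeter sketched after Example #1 (again, assumptions rather than the project's actual code):

import torch


def get_meter(meters):
    # Assumed helper: one AverageMeter per requested metric name.
    return {name: AverageMeter(name) for name in meters}


def to_cuda(batch_item, device, label=False):
    # Assumed helper: move a batch item to the computing device, casting labels to long.
    tensor = torch.as_tensor(batch_item).to(device, non_blocking=True)
    return tensor.long() if label else tensor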
Example #11
def main():
    print('start finetune')
    args = parse_args()
    args.dsbn = True if 'dsbn' in args.model_name else False  # set dsbn
    args.cpua = True if 'cpua' in args.model_name else False

    torch.cuda.set_device(args.gpu)  # set current gpu device id so pin_memory works on the target gpu
    start_time = datetime.datetime.now()  # execution start time

    # make save_dir
    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    # check whether teacher model exists
    if not os.path.isfile(args.teacher_model_path):
        raise AttributeError('Missing teacher model path: {}'.format(args.teacher_model_path))

    # create log file
    log_filename = 'train_records.log'
    log_path = os.path.join(args.save_dir, log_filename)
    logger = io_utils.get_logger(__name__, log_file=log_path, write_level=logging.INFO,
                                 print_level=logging.INFO if args.print_console else None,
                                 mode='a' if args.resume else 'w')

    # set num_classes by checking exp_setting
    if args.num_classes == 0:
        if args.exp_setting in ['office-home']:
            logger.warning('num_classes not specified; setting to 65 for office-home.')
            args.num_classes = 65
        elif args.exp_setting in ['digits']:
            logger.warning('num_classes not specified; setting to 10 for digits.')
            args.num_classes = 10
        else:
            raise AttributeError('Unknown exp_setting: {}'.format(args.exp_setting))

    if args.manual_seed:
        # set manual seed
        args.manual_seed = np.uint32(args.manual_seed)
        torch.manual_seed(args.manual_seed)
        torch.cuda.manual_seed(args.manual_seed)
        random.seed(args.manual_seed)
        np.random.seed(args.manual_seed)
        logger.info('Random Seed: {}'.format(int(args.manual_seed)))
        args.random_seed = args.manual_seed  # save seed into args
    else:
        seed = np.uint32(random.randrange(sys.maxsize))
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        random.seed(seed)
        np.random.seed(np.uint32(seed))
        logger.info('Random Seed: {}'.format(seed))
        args.random_seed = seed  # save seed into args

    if args.resume:
        logger.info('Resume training')
    else:
        logger.info('\nArguments:\n' + pprint.pformat(vars(args), indent=4))  # print args
    torch.save(vars(args), os.path.join(args.save_dir, 'args_dict.pth'))  # save args

    num_classes = args.num_classes
    in_features = args.in_features if args.in_features != 0 else num_classes
    # num_domains = len(args.source_datasets) + len(args.target_datasets)

    # tfboard
    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        tfboard_dir = os.path.join(args.save_dir, 'tfboard')
        if not os.path.isdir(tfboard_dir):
            os.makedirs(tfboard_dir)
        writer = SummaryWriter(tfboard_dir)

    # resume
    if args.resume:
        try:
            checkpoints = io_utils.load_latest_checkpoints(args.save_dir, args, logger)
        except FileNotFoundError:
            logger.warning('Latest checkpoints are not found! Trying to load best model...')
            checkpoints = io_utils.load_best_checkpoints(args.save_dir, args, logger)

        start_iter = checkpoints[0]['iteration'] + 1
    else:
        start_iter = 1

    ###################################################################################################################
    #                                               Data Loading                                                      #
    ###################################################################################################################

    # train_dataset = MNIST('/data/jihun/MNIST', train=True, transform=mnist_transform, download=True)
    # val_dataset = MNIST('/data/jihun/MNIST', train=False, transform=mnist_transform, download=True)
    train_dataset = SVHN(root='/data/jihun/SVHN', transform=svhn_transform, download=True)
    val_dataset = SVHN(root='/data/jihun/SVHN', split='test', transform=svhn_transform, download=True)
    train_dataloader = util_data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                            num_workers=args.num_workers, drop_last=True, pin_memory=True)
    val_dataloader = util_data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True,
                                          num_workers=args.num_workers, drop_last=True, pin_memory=True)

    train_dataloader_iters = enumerate(train_dataloader)
    val_dataloader_iter = enumerate(val_dataloader)

    ###################################################################################################################
    #                                               Model Loading                                                     #
    ###################################################################################################################
    model = get_model(args.model_name, 10, 0, 2, pretrained=True)

    params = get_optimizer_params(model, args.learning_rate, weight_decay=args.weight_decay,
                                  double_bias_lr=args.double_bias_lr, base_weight_factor=args.base_weight_factor)

    # teacher model
    # print(teacher_model)
    print('------------------------model load------------------------')
    model.load_state_dict(torch.load(args.teacher_model_path)['model'])
    # freeze everything except the classifier (fc*) and the target-domain BN
    # branch (bns.1); iterate over named_parameters() so requires_grad actually
    # takes effect (setting it on state_dict() tensors has no effect)
    for name, p in model.named_parameters():
        if ('fc' in name) or 'bns.1' in name:
            continue
        else:
            p.requires_grad = False

    torch.nn.init.xavier_uniform_(model.fc1.weight)
    torch.nn.init.xavier_uniform_(model.fc2.weight)

    model.train(True)
    model = model.cuda(args.gpu)

    ###################################################################################################################
    #                                               Train Configurations                                              #
    ###################################################################################################################
    ce_loss = nn.CrossEntropyLoss()

    optimizer = optim.Adam(params, betas=(args.beta1, args.beta2))

    # Train Starts
    logger.info('Train Starts')
    monitor = Monitor()

    global best_accuracy
    global best_accuracies_each_c
    global best_mean_val_accuracies
    global best_total_val_accuracies
    best_accuracy = 0.0
    best_accuracies_each_c = []
    best_mean_val_accuracies = []
    best_total_val_accuracies = []

    for i in range(start_iter, args.early_stop_step + 1):
        try:
            _, (x_s, y_s) = train_dataloader_iters.__next__()
        except StopIteration:
            train_dataloader_iters = enumerate(train_dataloader)
            _, (x_s, y_s) = train_dataloader_iters.__next__()
        # init optimizer
        optimizer.zero_grad()
        # ship to cuda
        x_s, y_s = x_s.cuda(args.gpu), y_s.cuda(args.gpu)
        pred_s, f_s = model(x_s, 1 * torch.ones(x_s.shape[0], dtype=torch.long).cuda(args.gpu),
                            with_ft=True)
        loss = ce_loss(pred_s, y_s)
        monitor.update({"Loss/Closs_src": float(loss)})

        loss.backward()
        optimizer.step()

        if i % args.save_interval == 0 and i != 0:
            # print('------------------------%d val start------------------------' % (i))
            logger.info("Elapsed Time: {}".format(datetime.datetime.now() - start_time))
            logger.info("Start Evaluation at {:d}".format(i))

            total_val_accuracies = []
            mean_val_accuracies = []
            val_accuracies_each_c = []
            model.eval()

            val_dataloader_iter = enumerate(val_dataloader)

            pred_vals = []
            y_vals = []
            x_val = None
            y_val = None
            # print('------------------------dataload------------------------')
            with torch.no_grad():
                for j, (x_val, y_val) in val_dataloader_iter:
                    y_vals.append(y_val.cpu())
                    x_val = x_val.cuda(args.gpu)
                    y_val = y_val.cuda(args.gpu)

                    pred_val = model(x_val, 1 * torch.ones_like(y_val), with_ft=False)

                    pred_vals.append(pred_val.cpu())
            # print('------------------------acc compute------------------------')
            pred_vals = torch.cat(pred_vals, 0)
            y_vals = torch.cat(y_vals, 0)
            total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
            val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                           class_idx=c, topk=(1,))[0]))
                                   for c, c_name in enumerate(val_dataset.classes)]

            logger.info('\n{} Accuracy of Each class\n'.format(args.finetune_dataset) +
                        ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                                 for c_name, c_val_acc in val_accuracy_each_c]))
            mean_val_accuracy = float(
                torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))
            # print('------------------------mean acc------------------------')
            logger.info('{} mean Accuracy: {:.2f}%'.format(
                args.finetune_dataset, 100 * mean_val_accuracy))
            logger.info(
                '{} Accuracy: {:.2f}%'.format(args.finetune_dataset,
                                              total_val_accuracy * 100))

            total_val_accuracies.append(total_val_accuracy)
            val_accuracies_each_c.append(val_accuracy_each_c)
            mean_val_accuracies.append(mean_val_accuracy)
            # print('------------------------tf board------------------------')
            if args.use_tfboard:
                writer.add_scalar('Val_acc', total_val_accuracy, i)
                for c_name, c_val_acc in val_accuracy_each_c:
                    writer.add_scalar('Val_acc_of_{}'.format(c_name), c_val_acc, i)

            model.train(True)  # train mode

            val_accuracy = float(torch.mean(torch.FloatTensor(total_val_accuracies)))

            del x_val, y_val, pred_val, pred_vals, y_vals
            del val_dataloader_iter
            print("%d th iter accuracy: %.3f" % (i, val_accuracy))
            # print('------------------------save model------------------------')
            if val_accuracy > best_accuracy:
                # save best model
                best_accuracy = val_accuracy
                best_accuracies_each_c = val_accuracies_each_c
                best_mean_val_accuracies = mean_val_accuracies
                best_total_val_accuracies = total_val_accuracies
                options = io_utils.get_model_options_from_args(args, i)
                # dict to save
                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}

                # save best checkpoint
                io_utils.save_checkpoints(args.save_dir, options, i, model_dict, optimizer_dict,
                                          logger, best=True)
                # ship to cuda
                model = model.cuda(args.gpu)

                # save best result into textfile
                contents = [' '.join(sys.argv) + '\n',
                            "best accuracy: {:.2f}%\n".format(best_accuracy)]
                best_accuracy_each_c = best_accuracies_each_c[0]
                best_mean_val_accuracy = best_mean_val_accuracies[0]
                best_total_val_accuracy = best_total_val_accuracies[0]
                contents.extend(["{}\n".format(args.finetune_dataset),
                                 "best total acc: {:.2f}%\n".format(100 * best_total_val_accuracy),
                                 "best mean acc: {:.2f}%\n".format(100 * best_mean_val_accuracy),
                                 'Best Accs: ' + ''.join(["{:.2f}% ".format(100 * c_val_acc)
                                                          for _, c_val_acc in best_accuracy_each_c]) + '\n'])

                best_result_path = os.path.join('./output', '{}_best_result.txt'.format(
                    os.path.splitext(os.path.basename(__file__))[0]))
                with open(best_result_path, 'a+') as f:
                    f.writelines(contents)

    val_dataloader_iter = enumerate(val_dataloader)

    pred_vals = []
    y_vals = []
    x_val = None
    y_val = None
    # print('------------------------dataload------------------------')
    with torch.no_grad():
        for j, (x_val, y_val) in val_dataloader_iter:
            y_vals.append(y_val.cpu())
            x_val = x_val.cuda(args.gpu)
            y_val = y_val.cuda(args.gpu)

            pred_val = model(x_val, 0 * torch.ones_like(y_val), with_ft=False)

            pred_vals.append(pred_val.cpu())

    pred_vals = torch.cat(pred_vals, 0)
    y_vals = torch.cat(y_vals, 0)
    total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
    val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                   class_idx=c, topk=(1,))[0]))
                           for c, c_name in enumerate(val_dataset.classes)]
    for cls in val_accuracy_each_c:
        print(cls)
    print(total_val_accuracy)
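Example #11 (and several of the examples below) call `eval_utils.accuracy` and `eval_utils.accuracy_of_c` without showing them. The sketch below captures the assumed contract: top-k accuracy returned as a fraction in [0, 1] (the calling code multiplies by 100 for display), with `accuracy_of_c` restricting the computation to samples of one class. The actual implementation in `eval_utils` may differ.

import torch


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy as a fraction in [0, 1] (assumed contract)."""
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)  # (N, maxk)
    correct = pred.eq(target.view(-1, 1).expand_as(pred))          # (N, maxk) bool
    return [correct[:, :k].any(dim=1).float().mean() for k in topk]


def accuracy_of_c(output, target, class_idx, topk=(1,)):
    """accuracy() restricted to samples whose label equals class_idx."""
    mask = target == class_idx
    if mask.sum() == 0:
        return [torch.tensor(0.0) for _ in topk]
    return accuracy(output[mask], target[mask], topk=topk)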
Example #12
0
def main():
    args = parse_args()
    torch.cuda.set_device(args.gpu)  # set the current GPU device id so pin_memory works on the target GPU
    if not os.path.isfile(args.model_path):
        raise IOError("ERROR model_path: {}".format(args.model_path))

    # load checkpoints
    checkpoint = torch.load(args.model_path)
    global_step = checkpoint['iteration']
    model_state_dict = checkpoint['model']

    # set logger
    model_dir = os.path.dirname(args.model_path)
    log_filename = 'evaluation_step{}.log'.format(global_step)
    log_path = os.path.join(model_dir, log_filename)
    logger = io_utils.get_logger(__name__, log_file=log_path, write_level=logging.INFO,
                                 print_level=logging.INFO if args.print_console else None)

    # set num_classes by checking exp_setting
    if args.num_classes == 0:
        if args.exp_setting == 'digits':
            logger.warning('num_classes not set; defaulting to 10 for digits.')
            args.num_classes = 10
        elif args.exp_setting == 'office':
            logger.warning('num_classes not set; defaulting to 31 for office.')
            args.num_classes = 31
        elif args.exp_setting in ['visda', 'imageclef']:
            logger.warning('num_classes not set; defaulting to 12 for visda/imageclef.')
            args.num_classes = 12
        elif args.exp_setting in ['office-home']:
            logger.warning('num_classes not set; defaulting to 65 for office-home.')
            args.num_classes = 65
        elif args.exp_setting in ['office-caltech']:
            args.num_classes = 10
        else:
            raise AttributeError('Wrong num_classes: {}'.format(args.num_classes))

    # update model args from filename
    model_args = io_utils.get_model_args_dict_from_filename(os.path.basename(args.model_path))
    model_args['source_datasets'] = model_args['source_dataset'].split('|')
    model_args['target_datasets'] = model_args['target_dataset'].split('|')
    args.__dict__.update(model_args)
    # load args if it exists
    args_path = os.path.join(model_dir, 'args_dict.pth')
    if os.path.isfile(args_path):
        logger.info('Argument file exists; loading arguments from {}'.format(args_path))
        args_dict = torch.load(args_path)
        update_dict = {'args_path': args_path,
                       'source_dataset': args_dict['source_dataset'],
                       'source_datasets': args_dict['source_datasets'],
                       'target_dataset': args_dict['target_dataset'],
                       'target_datasets': args_dict['target_datasets'],
                       'model_name': args_dict['model_name'],
                       'in_features': args_dict['in_features'], }
        args.__dict__.update(update_dict)
    args.dsbn = True if 'dsbn' in args.model_name else False  # set dsbn
    logger.info('\nArguments:\n' + pprint.pformat(vars(args), indent=4))

    model_options = io_utils.get_model_options_from_args(args, global_step)

    batch_size = args.batch_size
    num_classes = args.num_classes
    num_source_domains = len(args.source_datasets)
    num_target_domains = len(args.target_datasets)


    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        base_dir = os.path.dirname(args.model_path)
        tfboard_dir = os.path.join(base_dir, 'tfboard')
        if not os.path.isdir(tfboard_dir):
            os.makedirs(tfboard_dir)
        writer = SummaryWriter(tfboard_dir)
    ###################################################################################################################
    #                                               Data Loading                                                      #
    ###################################################################################################################

    source_test_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, source_dataset, 'test', args.jitter))
                            for source_dataset in args.source_datasets]
    target_test_datasets = [get_dataset("{}_{}_{}_{}".format(args.model_name, target_dataset, 'test', args.jitter))
                            for target_dataset in args.target_datasets]

    ###################################################################################################################
    #                                               Model Loading                                                     #
    ###################################################################################################################
    model = get_model(args.model_name, args.num_classes, args.in_features, pretrained=False)

    logger.info('Load trained parameters...')
    model.load_state_dict(model_state_dict)
    model.train(False)
    model.eval()
    model = model.cuda(args.gpu)

    # tfboard: write centroids
    if args.use_tfboard:
        centroids_filename = io_utils.get_centroids_filename(model_options)
        centroids_path = os.path.join(model_dir, centroids_filename)
        if os.path.isfile(centroids_path):
            logger.info('write centroids on tfboard: {}'.format(centroids_path))
            centroids_ckpt = torch.load(centroids_path)

            for i, centroids in enumerate(centroids_ckpt['srcs_centroids']):
                src_centroids = centroids['centroids'].cpu().data.numpy()
                writer.add_embedding(src_centroids, metadata=list(range(num_classes)),
                                     tag='src_centroids_{}'.format(args.source_datasets[i]), global_step=global_step)

            trg_centroids = centroids_ckpt['trg_centroids']['centroids'].cpu().data.numpy()
            writer.add_embedding(trg_centroids, metadata=list(range(num_classes)),
                                 tag='trg_centroids', global_step=global_step)

    logger.info('Start Evaluation')
    results = {'step': global_step}
    total_features = []
    total_labels = []

    # for d_idx, dataset in enumerate(target_test_datasets + source_test_datasets):
    for d_idx, dataset in enumerate(target_test_datasets):
        # dataloader
        dataloader = data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                                     num_workers=args.num_workers, drop_last=False, pin_memory=True)
        pred_vals = []
        y_vals = []
        if args.use_tfboard:
            features = []

        with torch.no_grad():
            for i, (x_val, y_val) in enumerate(dataloader):
                x_val = x_val.cuda(args.gpu)
                y_val = y_val.cuda(args.gpu)

                if args.dsbn:
                    pred_val, f_val = model(x_val, torch.zeros_like(y_val), with_ft=True)
                else:
                    pred_val, f_val = model(x_val, with_ft=True)

                pred_vals.append(pred_val.cpu())
                y_vals.append(y_val.cpu())
                if args.use_tfboard:
                    features += [f_val.cpu().data.numpy()]

        pred_vals = torch.cat(pred_vals, 0)
        y_vals = torch.cat(y_vals, 0)
        test_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
        val_accuracy_each_c = [(c_name, float(eval_utils.accuracy_of_c(pred_vals, y_vals,
                                                                       class_idx=c, topk=(1,))[0]))
                               for c, c_name in enumerate(dataset.classes)]
        # logging
        if d_idx < num_target_domains:
            logger.info('{} Test Accuracy: {:.4f}%'.format(args.target_datasets[d_idx], 100 * test_accuracy))
            logger.info('\nEach class Accuracy of {}\n'.format(args.target_datasets[d_idx]) +
                        ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                                 for c_name, c_val_acc in val_accuracy_each_c]))
            logger.info('Evaluation mean Accuracy: {:.2f}%'.format(
                100 * float(torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))))
            if args.save_results:
                results.update({args.target_datasets[d_idx]: test_accuracy})
                results.update(
                    {args.target_datasets[d_idx] + '_' + c_name: c_val_acc for c_name, c_val_acc in val_accuracy_each_c})
        else:
            logger.info('{} Test Accuracy: {:.4f}%'.format(args.source_datasets[d_idx - num_target_domains], 100 * test_accuracy))
            logger.info('\nEach class Accuracy of {}\n'.format(args.source_datasets[d_idx - num_target_domains]) +
                        ''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                                 for c_name, c_val_acc in val_accuracy_each_c]))
            logger.info('Evaluation mean Accuracy: {:.2f}%'.format(
                100 * float(torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))))
            if args.save_results:
                results.update({args.source_datasets[d_idx-num_target_domains]: test_accuracy})
                results.update(
                    {args.source_datasets[d_idx - num_target_domains] + '_' + c_name: c_val_acc for c_name, c_val_acc in
                     val_accuracy_each_c})

        if args.use_tfboard:
            features = np.concatenate(features, axis=0)
            y_vals_numpy = y_vals.numpy().astype(np.int64)  # np.int was removed from recent NumPy
            embed_features = features
            # u, s, vt = np.linalg.svd(features)
            # embed_features = np.dot(features, vt[:3, :].transpose())

            if d_idx < num_target_domains:
                total_features += [embed_features]
                total_labels += [args.target_datasets[d_idx][0] + str(int(l)) for l in y_vals]
                writer.add_embedding(embed_features, metadata=y_vals_numpy, tag=args.target_datasets[d_idx],
                                     global_step=global_step)
            else:
                total_features += [embed_features]
                total_labels += [args.source_datasets[d_idx-num_target_domains][0] + str(int(l)) for l in y_vals]
                writer.add_embedding(embed_features, metadata=y_vals_numpy, tag=args.source_datasets[d_idx - num_target_domains],
                                     global_step=global_step)

    if args.use_tfboard:
        total_features = np.concatenate(total_features, axis=0)
        writer.add_embedding(total_features, metadata=list(total_labels),
                             tag='feat_embed_S:{}_T:{}'.format(args.source_dataset, args.target_dataset),
                             global_step=global_step)

    # save results
    if args.save_results:
        result_filename = 'evaluation_{:06d}.pth'.format(global_step)
        torch.save(results, os.path.join(model_dir, result_filename))

    if args.use_tfboard:
        writer.close()
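The models in these examples are repeatedly called with an extra domain-index tensor (for example `model(x_val, 0 * torch.ones_like(y_val), with_ft=False)`), and the frozen-parameter filters single out `bns.1`. Both point to domain-specific batch normalization, where each domain owns its own BN branch. The sketch below illustrates that assumed mechanism only; the real `bns` modules in the checkpoints may differ in detail.

import torch
import torch.nn as nn


class DomainSpecificBatchNorm2d(nn.Module):
    """One BatchNorm branch per domain, selected by the domain index."""

    def __init__(self, num_features, num_domains=2):
        super().__init__()
        self.bns = nn.ModuleList(
            [nn.BatchNorm2d(num_features) for _ in range(num_domains)])

    def forward(self, x, domain_idx):
        # the callers above pass one shared index per batch,
        # e.g. 1 * torch.ones(batch_size, dtype=torch.long)
        return self.bns[int(domain_idx[0])](x)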
Example #13
0
def main():
    args = parse_args()
    stage = args.stage
    torch.cuda.set_device(args.gpu)
    writer = SummaryWriter()

    save_dir = args.save_dir
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    print('domain: ', args.trg_domain)

    num_domain = len(args.trg_domain)

    if (args.ssl):
        model = get_rot_model(args.model_name, num_domains=1)
        train_dataset = rot_dataset(args.data_root, num_domain,
                                    args.trg_domain, 'train')
        val_dataset = rot_dataset(args.data_root, num_domain, args.trg_domain,
                                  'val')
        test1_dataset = rot_dataset(args.data_root, 1, [args.trg_domain[0]],
                                    'test')
        if (len(args.trg_domain) > 1):
            test2_dataset = rot_dataset(args.data_root, 1,
                                        [args.trg_domain[1]], 'test')

    else:
        model = get_model(args.model_name, 65, 65, 1, pretrained=True)
        train_dataset = OFFICEHOME_multi(args.data_root,
                                         num_domain,
                                         args.trg_domain,
                                         transform=train_transform)
        val_dataset = OFFICEHOME_multi(args.data_root,
                                       num_domain,
                                       args.trg_domain,
                                       transform=val_transform)
        test1_dataset = OFFICEHOME_multi(args.data_root,
                                         1, [args.trg_domain[0]],
                                         transform=val_transform)
        if (len(args.trg_domain) > 1):
            test2_dataset = OFFICEHOME_multi(args.data_root,
                                             1, [args.trg_domain[1]],
                                             transform=val_transform)

    train_dataloader = util_data.DataLoader(train_dataset,
                                            batch_size=args.batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers,
                                            drop_last=True,
                                            pin_memory=True)

    train_dataloader_iter = enumerate(train_dataloader)

    model.train(True)
    model = model.cuda(args.gpu)

    ce_loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), betas=(0.9, 0.999))

    global best_accuracy
    global best_accuracies_each_c
    global best_mean_val_accuracies
    global best_total_val_accuracies
    best_accuracy = 0.0
    best_accuracies_each_c = []
    best_mean_val_accuracies = []
    best_total_val_accuracies = []

    for i in range(args.iters[0]):
        try:
            _, (x_s, y_s) = train_dataloader_iter.__next__()
        except StopIteration:
            train_dataloader_iter = enumerate(train_dataloader)
            _, (x_s, y_s) = train_dataloader_iter.__next__()
        optimizer.zero_grad()

        x_s, y_s = x_s.cuda(args.gpu), y_s.cuda(args.gpu)
        domain_idx = torch.ones(x_s.shape[0], dtype=torch.long).cuda(args.gpu)
        pred, f = model(x_s, 0 * domain_idx, with_ft=True)
        loss = ce_loss(pred, y_s)
        loss.backward()
        optimizer.step()

        if (i % 500 == 0 and i != 0):
            # print('------%d val start' % (i))
            model.eval()
            total_val_accuracies = []
            mean_val_accuracies = []
            val_accuracies_each_c = []

            val_dataloader = util_data.DataLoader(val_dataset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=args.num_workers,
                                                  drop_last=True,
                                                  pin_memory=True)
            val_dataloader_iter = enumerate(val_dataloader)

            pred_vals = []
            y_vals = []
            x_val = None
            y_val = None
            # print('------------------------dataload------------------------')
            with torch.no_grad():
                for j, (x_val, y_val) in val_dataloader_iter:
                    y_vals.append(y_val.cpu())
                    x_val = x_val.cuda(args.gpu)
                    y_val = y_val.cuda(args.gpu)

                    pred_val = model(x_val,
                                     0 * torch.ones_like(y_val),
                                     with_ft=False)

                    pred_vals.append(pred_val.cpu())

            pred_vals = torch.cat(pred_vals, 0)
            y_vals = torch.cat(y_vals, 0)
            total_val_accuracy = float(
                eval_utils.accuracy(pred_vals, y_vals, topk=(1, ))[0])
            val_accuracy_each_c = [
                (c_name,
                 float(
                     eval_utils.accuracy_of_c(pred_vals,
                                              y_vals,
                                              class_idx=c,
                                              topk=(1, ))[0]))
                for c, c_name in enumerate(val_dataset.classes)
            ]

            mean_val_accuracy = float(
                torch.mean(
                    torch.FloatTensor(
                        [c_val_acc for _, c_val_acc in val_accuracy_each_c])))
            total_val_accuracies.append(total_val_accuracy)
            val_accuracies_each_c.append(val_accuracy_each_c)
            mean_val_accuracies.append(mean_val_accuracy)

            val_accuracy = float(
                torch.mean(torch.FloatTensor(total_val_accuracies)))
            print('%d th iteration accuracy: %f ' % (i, val_accuracy))
            del x_val, y_val, pred_val, pred_vals, y_vals
            del val_dataloader_iter

            model_dict = {'model': model.cpu().state_dict()}
            optimizer_dict = {'optimizer': optimizer.state_dict()}

            # save best checkpoint
            # io_utils.save_check(save_dir, i, model_dict, optimizer_dict, best=False)

            model.train(True)  # train mode
            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                best_accuracies_each_c = val_accuracies_each_c
                best_mean_val_accuracies = mean_val_accuracies
                best_total_val_accuracies = total_val_accuracies
                # print('%d iter val acc %.3f' % (i, val_accuracy))
                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}

                # save best checkpoint
                io_utils.save_check(save_dir,
                                    i,
                                    model_dict,
                                    optimizer_dict,
                                    best=True)

            model = model.cuda(args.gpu)

        if (i % 5000 == 0 and i != 0):
            print('%d iter complete' % (i))
            test(args, test1_dataset, model)
            if (len(args.trg_domain) > 1):
                test(args, test2_dataset, model)

    writer.flush()
    writer.close()

    model.eval()
    test(args, test1_dataset, model)
    if (len(args.trg_domain) > 1):
        test(args, test2_dataset, model)
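Example #13's self-supervised branch builds its data with `rot_dataset`, which is not shown here. The usual rotation pretext task is assumed: each image is rotated by a multiple of 90 degrees and the rotation index becomes the 4-way classification target. The wrapper below is only an illustration of that assumption; the real `rot_dataset` takes root/domain arguments and its own transforms.

import random
import torch
from torch.utils.data import Dataset


class RotationPretextDataset(Dataset):
    """Wraps an image dataset for rotation prediction (illustrative only)."""

    def __init__(self, base_dataset):
        self.base = base_dataset

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, _ = self.base[idx]                 # discard the real class label
        k = random.randint(0, 3)                # rotation index 0..3
        img = torch.rot90(img, k, dims=(1, 2))  # rotate the CHW tensor by k*90 degrees
        return img, k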
Example #14
0
def train(train_loader, model, criterion, optimizer, epoch, args, writer):
    # batch_time = AverageMeter("Time", ":6.3f")
    # data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.3f")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    #l = [batch_time, data_time, losses, top1, top5]
    l = [losses, top1, top5]
    progress = ProgressMeter(
        len(train_loader),
        l,
        prefix=f"Epoch: [{epoch}]",
    )

    # switch to train mode
    model.train()

    batch_size = train_loader.batch_size
    num_batches = len(train_loader)
    end = time.time()
    image0, target0 = None, None
    for i, (images, target) in tqdm.tqdm(enumerate(train_loader),
                                         ascii=True,
                                         total=len(train_loader)):
        # if i == 0:
        image0 = images
        target0 = target
        # measure data loading time
        # data_time.update(time.time() - end)

        if args.gpu is not None:
            image0 = image0.cuda(args.gpu, non_blocking=True)

        target0 = target0.cuda(args.gpu, non_blocking=True)
        l = 0
        a1 = 0
        a5 = 0
        for j in range(args.K):
            output = model(image0)
            loss = criterion(output, target0)
            acc1, acc5 = accuracy(output, target0, topk=(1, 5))
            l = l + loss
            a1 = a1 + acc1
            a5 = a5 + acc5
        l = l / args.K
        a1 = a1 / args.K
        a5 = a5 / args.K
        # measure accuracy and record loss
        # torch.Size([128, 3, 32, 32])
        # 128
        losses.update(l.item(), image0.size(0))
        top1.update(a1.item(), image0.size(0))
        top5.update(a5.item(), image0.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        if args.conv_type != "SFESubnetConv":
            l.backward()
        else:
            updateScoreDiff(model, l)
        # printModelScore(model, args)
        optimizer.step()

        # measure elapsed time
        # batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            t = (num_batches * epoch + i) * batch_size
            progress.display(i)
            progress.write_to_tensorboard(writer,
                                          prefix="train",
                                          global_step=t)

    return top1.avg, top5.avg
Example #15
0
def train(train_loader, model, criterion, optimizer, epoch, args, writer):
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.3f")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix=f"Epoch: [{epoch}]",
    )

    # switch to train mode
    model.train()

    batch_size = train_loader.batch_size
    num_batches = len(train_loader)
    end = time.time()
    for i, (images, target) in tqdm.tqdm(enumerate(train_loader),
                                         ascii=True,
                                         total=len(train_loader)):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)

        target = target.cuda(args.gpu, non_blocking=True)

        # Write scores and weights to tensorboard at beginning of every other epoch
        if args.histograms:
            if (i % (num_batches * batch_size) == 0) and (epoch % 2 == 0):
                for param_name in model.state_dict():
                    #print(param_name)
                    # Only write scores for now (not weights and batch norm parameters since the pytorch parms don't actually change)
                    #if 'score' not in param_name:
                    #if 'score' in param_name or 'weight' in param_name:
                    #print(param_name, model.state_dict()[param_name])
                    writer.add_histogram(param_name,
                                         model.state_dict()[param_name], epoch)

        # compute output
        output = model(images)

        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1.item(), images.size(0))
        top5.update(acc5.item(), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        #torch.nn.utils.clip_grad_norm_(model.parameters(),1)
        loss.backward()
        # EDITED
        #print(torch.norm(torch.cat([p.grad.view(-1) for p in model.parameters()])))
        if args.grad_clip:
            torch.nn.utils.clip_grad_value_(model.parameters(), 1)
        #print(torch.norm(torch.cat([p.grad.view(-1) for p in model.parameters()])))
        #for param_name in model.state_dict(): print(param_name, str(model.state_dict()[param_name])[:50])
        #torch.nn.utils.clip_grad_norm_(model.parameters(),1)
        # end
        optimizer.step()

        # Clamp updated scores to [-1,1] only when using binarized/quantized activations
        #for param_name in model.state_dict():
        #  if 'score' in param_name:
        #    #print(param_name)
        #    scores = model.state_dict()[param_name]
        #    #scores = torch.clamp(scores,min=-1.0,max=1.0)
        #    scores.clamp_(min=-1.0,max=1.0)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        #print(model.state_dict()['module.linear.3.scores'].grad)
        #params = list(model.parameters())
        #print(params[1].grad)

        if i % args.print_freq == 0:
            t = (num_batches * epoch + i) * batch_size
            progress.display(i)

            #_, predicted = torch.max(output, 1)
            progress.write_to_tensorboard(writer,
                                          prefix="train",
                                          global_step=t)

        # Write score gradients to tensorboard at end of every other epoch
        if args.histograms:
            if (i % (num_batches * batch_size) == 0) and (epoch % 2 == 0):
                #if ((i+1) % (num_batches-1) == 0) and (epoch % 2 == 0):
                params = list(model.parameters())
                param_names = list(model.state_dict())
                for j in range(len(params)):
                    if params[j].grad is not None:
                        # if 'score' in param_names[j] or 'weight' in param_names[j]:
                        # if 'score' not in param_name and params[j].grad is not None:
                        #print(param_names[j])
                        #print(params[j].grad)
                        writer.add_histogram(param_names[j] + '.grad',
                                             params[j].grad, epoch)
                    else:
                        writer.add_histogram(param_names[j] + '.grad', 0,
                                             epoch)
                #for param_name in model.state_dict():
                #  if 'score' in param_name:
                #    writer.add_histogram(param_name + '.grad', model.state_dict()[param_name].grad, epoch)
                #params = list(model.parameters())
                #for j in range(len(params)):
                #  writer.add_histogram('Layer' + str(j) + 'grad', params[j].grad, epoch)

    # Write final scores and weights to tensorboard
    if args.histograms:
        for param_name in model.state_dict():
            #writer.add_histogram(param_name, model.state_dict()[param_name], epoch)
            # Only write scores for now (not weights and batch norm parameters since the pytorch parms don't actually change)
            #if 'score' not in param_name:
            #print(param_name, model.state_dict()[param_name])
            writer.add_histogram(param_name,
                                 model.state_dict()[param_name], epoch)

    return top1.avg, top5.avg
Example #16
0
def main():
    args = parse_args()
    stage = args.stage
    torch.cuda.set_device(args.gpu)

    if (stage == 1):
        save_dir = join(save_root, args.save_dir, 'stage1')
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir, exist_ok=True)
        print('domain: ', args.trg_domain)
        model = get_model(args.model_name, 65, 65, 2, pretrained=True)

        num_domain = len(args.trg_domain)
        train_dataset = OFFICEHOME_multi(args.data_root, num_domain, args.trg_domain, transform=train_transform)
        val_dataset = OFFICEHOME_multi(args.data_root, num_domain, args.trg_domain, transform=val_transform)
        train(args, model, train_dataset, val_dataset, stage, save_dir)

        if (args.proceed):
            stage += 1

    if (stage == 2):
        save_dir = join(save_root, args.save_dir, 'stage2')
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir, exist_ok=True)
        print('domain: ', args.src_domain)
        model = get_model(args.model_name, 65, 65, 2, pretrained=True)
        if (args.proceed):
            model.load_state_dict(torch.load(join(save_root, args.save_dir, 'stage1', 'best_model.ckpt'))['model'])
        else:
            model.load_state_dict(torch.load(join(save_root, args.model_path))['model'])

        for name, p in model.named_parameters():
            if ('fc' in name) or 'bns.1' in name:
                continue
            else:
                p.requires_grad = False
        torch.nn.init.xavier_uniform_(model.fc1.weight)
        torch.nn.init.xavier_uniform_(model.fc2.weight)
        num_domain = len(args.trg_domain)
        train_dataset = OFFICEHOME_multi(args.data_root, num_domain, args.src_domain, transform=train_transform)
        val_dataset = OFFICEHOME_multi(args.data_root, num_domain, args.src_domain, transform=val_transform)

        train(args, model, train_dataset, val_dataset, stage, save_dir)

        if (args.proceed):
            num_domain = len(args.trg_domain)
            val_dataset = OFFICEHOME_multi(args.data_root, num_domain, args.trg_domain, transform=val_transform)
            val_dataloader = util_data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True,
                                                  num_workers=args.num_workers, drop_last=True, pin_memory=True)
            val_dataloader_iter = enumerate(val_dataloader)
            domain_num = domain_dict[args.trg_domain[0]]
            pred_vals = []
            y_vals = []
            x_val = None
            y_val = None
            # print('------------------------dataload------------------------')
            with torch.no_grad():
                for j, (x_val, y_val) in val_dataloader_iter:
                    y_vals.append(y_val.cpu())
                    x_val = x_val.cuda(args.gpu)
                    y_val = y_val.cuda(args.gpu)

                    pred_val = model(x_val, domain_num * torch.ones_like(y_val), with_ft=False)

                    pred_vals.append(pred_val.cpu())

            pred_vals = torch.cat(pred_vals, 0)
            y_vals = torch.cat(y_vals, 0)
            total_val_accuracy = float(eval_utils.accuracy(pred_vals, y_vals, topk=(1,))[0])
Example #17
0
def ps_test(args, teacher, student, val_dataset, domain_num):
    val_dataloader = util_data.DataLoader(val_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=args.num_workers,
                                          drop_last=True,
                                          pin_memory=True)
    val_dataloader_iter = enumerate(val_dataloader)

    val_accs_each_c = []
    student_accs_each_c = []

    pseu_ys = []
    pred_ys = []
    y_vals = []
    x_val = None
    y_val = None

    teacher.eval()
    student.eval()

    with torch.no_grad():
        for j, (x_val, y_val) in val_dataloader_iter:
            y_vals.append(y_val.cpu())
            x_val = x_val.cuda(args.gpu)
            y_val = y_val.cuda(args.gpu)

            # domain index 1 follows the original DSBN convention (0: src, 1: trg)
            pseu_y = teacher(x_val, 1 * torch.ones_like(y_val),
                             with_ft=False).argmax(axis=1)
            pred_y = student(x_val,
                             domain_num * torch.ones_like(y_val),
                             with_ft=False)
            pseu_ys.append(pseu_y.cpu())
            pred_ys.append(pred_y.cpu())

    pred_ys = torch.cat(pred_ys, 0)
    pseu_ys = torch.cat(pseu_ys, 0)
    y_vals = torch.cat(y_vals, 0)

    val_acc = float(eval_utils.accuracy(pred_ys, y_vals, topk=(1, ))[0])
    val_acc_each_c = [(c_name,
                       float(
                           eval_utils.accuracy_of_c(pred_ys,
                                                    y_vals,
                                                    class_idx=c,
                                                    topk=(1, ))[0]))
                      for c, c_name in enumerate(val_dataset.classes)]
    student_acc = float(eval_utils.accuracy(pred_ys, pseu_ys, topk=(1, ))[0])
    student_acc_each_c = [(c_name,
                           float(
                               eval_utils.accuracy_of_c(pred_ys,
                                                        pseu_ys,
                                                        class_idx=c,
                                                        topk=(1, ))[0]))
                          for c, c_name in enumerate(val_dataset.classes)]
    val_accs_each_c.append(val_acc_each_c)
    student_accs_each_c.append(student_acc_each_c)

    del x_val, y_val, pred_y, pred_ys, pseu_y, pseu_ys, y_vals
    del val_dataloader_iter

    return student, val_acc, student_acc
Example #18
0
def train(model,
          trainer,
          train_loader,
          epoch,
          logger,
          tb_logger,
          batch_size=opt.batch_size,
          print_freq=opt.print_freq):
    """ Train the model

    Beyond the typical training loop, `train()` handles additional
    bookkeeping and wrapper functionality: it tracks accuracy, loss, and
    batch time, and relies on the `trainer` to wrap the optimizer and the
    loss function. See `trainer.py` or `utils/eval_utils.py` for details.

    Args:
        model: Classification model
        trainer (Trainer): Training wrapper
        train_loader (torch.utils.data.DataLoader): Training data loader
        epoch (int): Current epoch
        logger (Logger): Logger. Used to display/log metrics
        tb_logger (SummaryWriter): Tensorboard Logger
        batch_size (int): Batch size
        print_freq (int): Print frequency

    Returns:
        tuple: (average loss, average accuracy) for the epoch

    """
    criterion = trainer.criterion
    optimizer = trainer.optimizer

    # Initialize meter to bookkeep the following parameters
    meter = get_meter(meters=['batch_time', 'data_time', 'loss', 'acc'])

    # Switch to training mode
    model.train(True)

    end = time.time()
    for i, batch in enumerate(train_loader):
        # process batch items: images, labels
        img = to_cuda(batch[CONST.IMG], trainer.computing_device)
        target = to_cuda(batch[CONST.LBL],
                         trainer.computing_device,
                         label=True)
        id = batch[CONST.ID]

        # measure data loading time
        meter['data_time'].update(time.time() - end)

        # compute output
        end = time.time()
        logits = model(img)
        loss = criterion(logits, target)
        acc = accuracy(logits, target)

        # update metrics
        meter['acc'].update(acc, batch_size)
        meter['loss'].update(loss.item(), batch_size)  # store a plain float so the autograd graph can be freed

        # compute gradient and do sgd step
        optimizer.zero_grad()
        loss.backward()

        if i % print_freq == 0:
            log = 'TRAIN [{:02d}][{:2d}/{:2d}] TIME {:10} DATA {:10} ACC {:10} LOSS {:10}'.\
                format(epoch, i, len(train_loader),
                       "{t.val:.3f} ({t.avg:.3f})".format(t=meter['batch_time']),
                       "{t.val:.3f} ({t.avg:.3f})".format(t=meter['data_time']),
                       "{t.val:.3f} ({t.avg:.3f})".format(t=meter['acc']),
                       "{t.val:.3f} ({t.avg:.3f})".format(t=meter['loss'])
                                )
            logger.info(log)

            tb_logger.add_scalar('train/loss', meter['loss'].val,
                                 epoch * len(train_loader) + i)
            tb_logger.add_scalar('train/accuracy', meter['acc'].val,
                                 epoch * len(train_loader) + i)
            tb_logger.add_scalar('data_time', meter['data_time'].val,
                                 epoch * len(train_loader) + i)
            tb_logger.add_scalar(
                'compute_time',
                meter['batch_time'].val - meter['data_time'].val,
                epoch * len(train_loader) + i)

        optimizer.step()

        # measure elapsed time
        meter['batch_time'].update(time.time() - end)
        end = time.time()

    tb_logger.add_scalar('train-epoch/loss', meter['loss'].avg, epoch)
    tb_logger.add_scalar('train-epoch/accuracy', meter['acc'].avg, epoch)

    return meter['loss'].avg, meter['acc'].avg
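Example #18 depends on a `to_cuda` helper and on a `trainer` object exposing `criterion`, `optimizer`, and `computing_device`, none of which appear in this snippet. The sketch below states those assumptions explicitly; the names, defaults, and constructor arguments are illustrative, not the project's actual `trainer.py`.

import torch
import torch.nn as nn


def to_cuda(x, device, label=False):
    """Move a batch tensor to the training device (assumed helper).

    Labels are cast to long for CrossEntropyLoss; images stay as float.
    """
    x = x.to(device, non_blocking=True)
    return x.long() if label else x.float()


class Trainer:
    """Assumed shape of the training wrapper referenced in the docstring."""

    def __init__(self, model, lr=1e-3, device='cuda:0'):
        self.computing_device = torch.device(device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)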
Example #19
0
def main():
    args = parse_args()
    torch.cuda.set_device(args.gpu)
    stage = args.stage

    global best_accuracy
    global best_accuracies_each_c
    global best_mean_val_accuracies
    global best_total_val_accuracies

    svhn_train = SVHN(root='/data/jihun/SVHN',
                      transform=svhn_transform,
                      download=True)
    svhn_val = SVHN(root='/data/jihun/SVHN',
                    split='test',
                    transform=svhn_transform,
                    download=True)
    mnist_train = MNIST('/data/jihun/MNIST',
                        train=True,
                        transform=mnist_transform,
                        download=True)
    mnist_val = MNIST('/data/jihun/MNIST',
                      train=False,
                      transform=mnist_transform,
                      download=True)

    if (stage == 1):
        save_dir = join(save_root, args.save_dir, 'stage1')
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        if (args.trg_domain == 'mnist'):
            train_dataset = mnist_train
            val_dataset = mnist_val
        else:
            train_dataset = svhn_train
            val_dataset = svhn_val

        train_dataloader = util_data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=5,
                                                drop_last=True,
                                                pin_memory=True)

        train_dataloader_iters = enumerate(train_dataloader)
        model = DSBNLeNet(num_classes=10, in_features=0, num_domains=2)
        model.train(True)
        model = model.cuda(args.gpu)
        optimizer = optim.Adam(model.parameters(), betas=(0.9, 0.999))
        ce_loss = nn.CrossEntropyLoss()
        domain_num = 0

        best_accuracy = 0.0
        best_accuracies_each_c = []
        best_mean_val_accuracies = []
        best_total_val_accuracies = []

        writer = SummaryWriter()
        for i in range(args.iters[0]):
            try:
                _, (x_s, y_s) = train_dataloader_iters.__next__()
            except StopIteration:
                train_dataloader_iters = enumerate(train_dataloader)
                _, (x_s, y_s) = train_dataloader_iters.__next__()

            optimizer.zero_grad()
            # lr_scheduler(optimizer, i)

            x_s, y_s = x_s.cuda(args.gpu), y_s.cuda(args.gpu)
            # x_s = x_s.cuda(args.gpu)
            domain_idx = torch.ones(x_s.shape[0],
                                    dtype=torch.long).cuda(args.gpu)
            pred, f = model(x_s, domain_num * domain_idx, with_ft=True)
            loss = ce_loss(pred, y_s)
            # print(loss)
            writer.add_scalar("Train Loss", loss, i)
            loss.backward()
            optimizer.step()

            if (i % 500 == 0 and i != 0):
                # print('------%d val start' % (i))
                model.eval()
                total_val_accuracies = []
                mean_val_accuracies = []
                val_accuracies_each_c = []

                val_dataloader = util_data.DataLoader(
                    val_dataset,
                    batch_size=args.batch_size,
                    shuffle=True,
                    num_workers=args.num_workers,
                    drop_last=True,
                    pin_memory=True)
                val_dataloader_iter = enumerate(val_dataloader)

                pred_vals = []
                y_vals = []
                x_val = None
                y_val = None

                with torch.no_grad():
                    for j, (x_val, y_val) in val_dataloader_iter:
                        y_vals.append(y_val.cpu())
                        x_val = x_val.cuda(args.gpu)
                        y_val = y_val.cuda(args.gpu)

                        pred_val = model(x_val,
                                         domain_num * torch.ones_like(y_val),
                                         with_ft=False)

                        pred_vals.append(pred_val.cpu())

                pred_vals = torch.cat(pred_vals, 0)
                y_vals = torch.cat(y_vals, 0)
                total_val_accuracy = float(
                    eval_utils.accuracy(pred_vals, y_vals, topk=(1, ))[0])
                val_accuracy_each_c = [
                    (c_name,
                     float(
                         eval_utils.accuracy_of_c(pred_vals,
                                                  y_vals,
                                                  class_idx=c,
                                                  topk=(1, ))[0]))
                    for c, c_name in enumerate(val_dataset.classes)
                ]

                mean_val_accuracy = float(
                    torch.mean(
                        torch.FloatTensor([
                            c_val_acc for _, c_val_acc in val_accuracy_each_c
                        ])))
                total_val_accuracies.append(total_val_accuracy)
                val_accuracies_each_c.append(val_accuracy_each_c)
                mean_val_accuracies.append(mean_val_accuracy)

                val_accuracy = float(
                    torch.mean(torch.FloatTensor(total_val_accuracies)))
                print('%d th iteration accuracy: %f ' % (i, val_accuracy))
                del x_val, y_val, pred_val, pred_vals, y_vals
                del val_dataloader_iter

                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}

                # save best checkpoint
                io_utils.save_check(save_dir,
                                    i,
                                    model_dict,
                                    optimizer_dict,
                                    best=False)

                model.train(True)  # train mode
                if val_accuracy > best_accuracy:
                    best_accuracy = val_accuracy
                    best_accuracies_each_c = val_accuracies_each_c
                    best_mean_val_accuracies = mean_val_accuracies
                    best_total_val_accuracies = total_val_accuracies
                    # print('%d iter val acc %.3f' % (i, val_accuracy))
                    model_dict = {'model': model.cpu().state_dict()}
                    optimizer_dict = {'optimizer': optimizer.state_dict()}

                    # save best checkpoint
                    io_utils.save_check(save_dir,
                                        i,
                                        model_dict,
                                        optimizer_dict,
                                        best=True)

                model = model.cuda(args.gpu)
        if args.proceed:
            stage += 1
    if (stage == 2):
        save_dir = join(save_root, args.save_dir, 'stage2')
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        if (args.src_domain == 'mnist'):
            train_dataset = mnist_train
            val_dataset = mnist_val
        else:
            train_dataset = svhn_train
            val_dataset = svhn_val

        train_dataloader = util_data.DataLoader(train_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=args.num_workers,
                                                drop_last=True,
                                                pin_memory=True)
        train_dataloader_iters = enumerate(train_dataloader)

        model = DSBNLeNet(num_classes=10, in_features=0, num_domains=2)
        if (args.proceed):
            model.load_state_dict(
                torch.load(
                    join(save_root, args.save_dir, 'stage1',
                         'best_model.ckpt'))['model'])
        else:
            # join the path components first; torch.load's second positional
            # argument is map_location, not part of the file path
            model.load_state_dict(
                torch.load(join(save_root, args.model_path))['model'])

        for name, p in model.named_parameters():
            if ('fc' in name) or 'bns.1' in name:
                p.requires_grad = True
                continue
            else:
                p.requires_grad = False

        torch.nn.init.xavier_uniform_(model.fc1.weight)
        torch.nn.init.xavier_uniform_(model.fc2.weight)

        model.train(True)
        model = model.cuda(args.gpu)

        params = get_optimizer_params(model,
                                      args.learning_rate,
                                      weight_decay=args.weight_decay,
                                      double_bias_lr=True,
                                      base_weight_factor=0.1)

        optimizer = optim.Adam(params, betas=(0.9, 0.999))
        ce_loss = nn.CrossEntropyLoss()

        writer = SummaryWriter()
        domain_num = stage - 1
        print('domain_num, stage: ', domain_num, stage)

        best_accuracy = 0.0
        best_accuracies_each_c = []
        best_mean_val_accuracies = []
        best_total_val_accuracies = []

        for i in range(args.iters[stage - 1]):
            try:
                _, (x_s, y_s) = next(train_dataloader_iters)
            except StopIteration:
                # restart the loader once it is exhausted
                train_dataloader_iters = enumerate(train_dataloader)
                _, (x_s, y_s) = next(train_dataloader_iters)
            optimizer.zero_grad()
            # lr_scheduler(optimizer, i)

            x_s, y_s = x_s.cuda(args.gpu), y_s.cuda(args.gpu)
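            # every sample in the batch is tagged with the same domain index so
            # the DSBN layers normalize it with this domain's statistics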
            domain_idx = torch.ones(x_s.shape[0],
                                    dtype=torch.long).cuda(args.gpu)
            pred, f = model(x_s, domain_num * domain_idx, with_ft=True)
            loss = ce_loss(pred, y_s)
            writer.add_scalar("Train Loss", loss, i)
            loss.backward()
            optimizer.step()

            if (i % 500 == 0 and i != 0):
                # switch to eval mode for the periodic validation pass
                model.eval()
                total_val_accuracies = []
                mean_val_accuracies = []
                val_accuracies_each_c = []

                val_dataloader = util_data.DataLoader(
                    val_dataset,
                    batch_size=args.batch_size,
                    shuffle=True,
                    num_workers=args.num_workers,
                    drop_last=True,
                    pin_memory=True)
                val_dataloader_iter = enumerate(val_dataloader)

                pred_vals = []
                y_vals = []
                x_val = None
                y_val = None
                # print('------------------------dataload------------------------')
                with torch.no_grad():
                    for j, (x_val, y_val) in val_dataloader_iter:
                        y_vals.append(y_val.cpu())
                        x_val = x_val.cuda(args.gpu)
                        y_val = y_val.cuda(args.gpu)

                        pred_val = model(x_val,
                                         domain_num * torch.ones_like(y_val),
                                         with_ft=False)

                        pred_vals.append(pred_val.cpu())

                pred_vals = torch.cat(pred_vals, 0)
                y_vals = torch.cat(y_vals, 0)
                total_val_accuracy = float(
                    eval_utils.accuracy(pred_vals, y_vals, topk=(1, ))[0])
                val_accuracy_each_c = [
                    (c_name,
                     float(
                         eval_utils.accuracy_of_c(pred_vals,
                                                  y_vals,
                                                  class_idx=c,
                                                  topk=(1, ))[0]))
                    for c, c_name in enumerate(val_dataset.classes)
                ]

                mean_val_accuracy = float(
                    torch.mean(
                        torch.FloatTensor([
                            c_val_acc for _, c_val_acc in val_accuracy_each_c
                        ])))
                total_val_accuracies.append(total_val_accuracy)
                val_accuracies_each_c.append(val_accuracy_each_c)
                mean_val_accuracies.append(mean_val_accuracy)

                val_accuracy = float(
                    torch.mean(torch.FloatTensor(total_val_accuracies)))
                print('%d th iteration accuracy: %f ' % (i, val_accuracy))
                del x_val, y_val, pred_val, pred_vals, y_vals
                del val_dataloader_iter

                model_dict = {'model': model.cpu().state_dict()}
                optimizer_dict = {'optimizer': optimizer.state_dict()}

                # save best checkpoint
                io_utils.save_check(save_dir,
                                    i,
                                    model_dict,
                                    optimizer_dict,
                                    best=False)

                model.train(True)  # train mode
                if val_accuracy > best_accuracy:
                    best_accuracy = val_accuracy
                    best_accuracies_each_c = val_accuracies_each_c
                    best_mean_val_accuracies = mean_val_accuracies
                    best_total_val_accuracies = total_val_accuracies
                    # print('%d iter val acc %.3f' % (i, val_accuracy))
                    model_dict = {'model': model.cpu().state_dict()}
                    optimizer_dict = {'optimizer': optimizer.state_dict()}

                    # save best checkpoint
                    io_utils.save_check(save_dir,
                                        i,
                                        model_dict,
                                        optimizer_dict,
                                        best=True)

                model = model.cuda(args.gpu)
        if (args.proceed):
            stage += 1

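    # stage 3: evaluate the resulting model on the target-domain validation set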
    if (stage == 3):

        if (args.trg_domain == 'mnist'):
            val_dataset = mnist_val
        else:
            val_dataset = svhn_val

        val_dataloader = util_data.DataLoader(val_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=5,
                                              drop_last=True,
                                              pin_memory=True)

        val_dataloader_iter = enumerate(val_dataloader)

        model = DSBNLeNet(num_classes=10, in_features=0, num_domains=2)
        if (args.proceed):
            model.load_state_dict(
                torch.load(
                    join(save_root, args.save_dir, 'stage2',
                         'best_model.ckpt'))['model'])
        else:
            model.load_state_dict(
                torch.load(join(save_root, args.model_path))['model'])
        model = model.cuda(args.gpu)

        pred_vals = []
        y_vals = []
        domain_num = 0

        with torch.no_grad():
            for j, (x_val, y_val) in val_dataloader_iter:
                y_vals.append(y_val.cpu())
                x_val = x_val.cuda(args.gpu)
                y_val = y_val.cuda(args.gpu)

                pred_val = model(x_val,
                                 domain_num * torch.ones_like(y_val),
                                 with_ft=False)

                pred_vals.append(pred_val.cpu())

        pred_vals = torch.cat(pred_vals, 0)
        y_vals = torch.cat(y_vals, 0)
        total_val_accuracy = float(
            eval_utils.accuracy(pred_vals, y_vals, topk=(1, ))[0])
        val_accuracy_each_c = [(c_name,
                                float(
                                    eval_utils.accuracy_of_c(pred_vals,
                                                             y_vals,
                                                             class_idx=c,
                                                             topk=(1, ))[0]))
                               for c, c_name in enumerate(val_dataset.classes)]

        print(total_val_accuracy)
        print(val_accuracy_each_c)
Example #20
0
def run_session(train,
                test,
                LB,
                data_cols,
                graph,
                session_state,
                num_itrs,
                name,
                batch_size,
                k_prob=1.0,
                mute=False,
                record=False):
    """Run tensorflow session

    The session is essentially the training loop, that runs the graph that was initialized in the main.

    Args:
        train (pd.DataFrame): Training set
        test (pd.DataFrame): Test set
        LB (sklearn.LabelBinarizer): Fitted one-hot label encoder
        data_cols (pd.DataFrame.columns): Feature columns
        graph: Tensorflow Graph
        session_state (dict): Session state dictionary containing loss, optimizer, and prediction
        num_itrs (int): Number of training iterations (minibatches)
        name (str): Model type
        batch_size (int): Batch size
        k_prob (float): Dropout keep probability
        mute (bool): If True, suppress intermediate progress logging
        record (bool): If True, write TensorBoard summaries for the session

    Returns:
        None

    """
    # Stats container
    """Use this to access the accuracy and predictions, given the model_name
    >> test_preds['RNN']
    0.983
    """
    acc_over_time = {}
    test_preds = {}
    loss_over_time = {}

    start = timer()
    test_labels = LB.transform(test['label'])

    with tf.Session(graph=graph) as session:
        if record:
            # tf.merge_all_summaries() / tf.train.SummaryWriter are the pre-1.0
            # names; the tf.summary equivalents are used here
            merged = tf.summary.merge_all()
            writer = tf.summary.FileWriter("/tmp/tensorflowlogs",
                                           session.graph)
        tf.global_variables_initializer().run()

        print("Initialized")
        accu = []
        loss = []

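        # each iteration draws a fresh random minibatch from the training
        # DataFrame rather than sweeping fixed epochs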
        for iteration in range(num_itrs):

            # get batch
            train_batch = train.sample(batch_size)
            t_d = train_batch[data_cols]
            t_l = LB.transform(train_batch['label'])

            # make feed dict
            feed_dict = init_feed_dict(mode='train',
                                       data=t_d,
                                       label=t_l,
                                       prob=k_prob)

            # run model on batch
            _, l, predictions = session.run([
                session_state['optimizer'], session_state['loss'],
                session_state['prediction']
            ],
                                            feed_dict=feed_dict)

            # mid model accuracy checks
            if (iteration % 1000 == 0) and not mute:
                loss.append(l)
                print("\tMinibatch loss at iteration {}: {}".format(
                    iteration, l))
                print("\tMinibatch accuracy: {:.1f}".format(
                    accuracy(predictions, t_l)))
            if (iteration % 5000 == 0) and not mute:
                print("Test accuracy: {:.1f}".format(
                    test_accuracy(session,
                                  test_data=test,
                                  test_labels=test_labels,
                                  data_cols=data_cols,
                                  prediction=session_state['prediction'],
                                  during=True)))
            if (iteration % 1000 == 0) and not mute:
                accu.append(
                    tuple([
                        iteration,
                        test_accuracy(session,
                                      test_data=test,
                                      test_labels=test_labels,
                                      data_cols=data_cols,
                                      prediction=session_state['prediction'],
                                      during=True)
                    ]))

        # record accuracy and predictions
        test_preds[name] = test_accuracy(
            session,
            test_data=test,
            test_labels=test_labels,
            data_cols=data_cols,
            prediction=session_state['prediction'],
            during=False)
        print("Final Test accuracy: {:.1f}".format(
            accuracy(test_preds[name], test_labels)))
        end = timer()
        test_preds[name] = test_preds[name].ravel()
        acc_over_time[name] = accu
        loss_over_time[name] = loss
        print("time taken: {0} minutes {1:.1f} seconds".format(
            (end - start) // 60, (end - start) % 60))

        # Save data
        dump_data(data=acc_over_time)
        dump_data(data=test_preds, fname='data/test_preds.p')
        # dump_data(data=test_labels, fname='data/test_labels.p')
        dump_data(data=loss_over_time, fname='data/dataset_loss.p')