Example #1
def train_model(data_loader, model, criterion, optimizer, epoch, log,
                print_freq=200, use_cuda=True):
    # train function (forward, backward, update)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    # switch to train mode
    model.train()

    end = time.time()
    for iteration, (input, target) in enumerate(data_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            target = target.cuda()
            input = input.cuda()

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        if len(target.shape) > 1:
            target = torch.argmax(target, dim=-1)
        prec1, = accuracy(output.data, target, topk=(1,))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if iteration % print_freq == 0:
            print_log('  Epoch: [{:03d}][{:03d}/{:03d}]   '
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})   '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})   '
                        'Loss {loss.val:.4f} ({loss.avg:.4f})   '
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})   '.format(
                        epoch, iteration, len(data_loader), batch_time=batch_time,
                        data_time=data_time, loss=losses, top1=top1) + time_string(), log)

    print_log('  **Train** Prec@1 {top1.avg:.3f} Error@1 {error1:.3f}'.format(top1=top1, error1=100-top1.avg), log)
    return top1.avg, losses.avg
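
The helpers used above (AverageMeter, accuracy, print_log, time_string) come from the surrounding repository and are not shown here. For reference, a minimal sketch of an AverageMeter compatible with the calls above, assuming the common PyTorch-examples implementation:

class AverageMeter:
    """Tracks the latest value and a running average."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # n is the weight of this observation, e.g. the batch size
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count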
Example #2
def main():
    args = parse_arguments()

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.use_cuda:
        torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True

    model_path = get_model_path(args.dataset, args.arch, args.seed)

    # Init logger
    log_file_name = os.path.join(model_path, 'log.txt')
    print("Log file: {}".format(log_file_name))
    log = open(log_file_name, 'w')
    print_log('model path : {}'.format(model_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    for key, value in state.items():
        print_log("{} : {}".format(key, value), log)
    print_log("Random Seed: {}".format(args.seed), log)
    print_log("Python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("Torch  version : {}".format(torch.__version__), log)
    print_log("Cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # Data specifications for the websites dataset
    mean = [0., 0., 0.]
    std = [1., 1., 1.]
    input_size = 224
    num_classes = 4

    # Dataset
    traindir = os.path.join(WEBSITES_DATASET_PATH, 'train')
    valdir = os.path.join(WEBSITES_DATASET_PATH, 'val')

    train_transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    test_transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    data_train = dset.ImageFolder(root=traindir, transform=train_transform)
    data_test = dset.ImageFolder(root=valdir, transform=test_transform)

    # Dataloader
    data_train_loader = torch.utils.data.DataLoader(data_train,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=args.workers,
                                                    pin_memory=True)
    data_test_loader = torch.utils.data.DataLoader(data_test,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

    # Network
    if args.arch == "vgg16":
        net = models.vgg16(pretrained=True)
    elif args.arch == "vgg19":
        net = models.vgg19(pretrained=True)
    elif args.arch == "resnet18":
        net = models.resnet18(pretrained=True)
    elif args.arch == "resnet50":
        net = models.resnet50(pretrained=True)
    elif args.arch == "resnet101":
        net = models.resnet101(pretrained=True)
    elif args.arch == "resnet152":
        net = models.resnet152(pretrained=True)
    else:
        raise ValueError("Network {} not supported".format(args.arch))

    if num_classes != 1000:
        net = manipulate_net_architecture(model_arch=args.arch,
                                          net=net,
                                          num_classes=num_classes)

    # Loss function
    if args.loss_function == "ce":
        criterion = torch.nn.CrossEntropyLoss()
    else:
        raise ValueError("Loss function {} not supported".format(args.loss_function))

    # Cuda
    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    # Optimizer
    momentum = 0.9
    decay = 5e-4
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.learning_rate,
                                momentum=momentum,
                                weight_decay=decay,
                                nesterov=True)

    recorder = RecorderMeter(args.epochs)
    start_time = time.time()
    epoch_time = AverageMeter()

    # Main loop
    for epoch in range(args.epochs):
        current_learning_rate = adjust_learning_rate(args.learning_rate,
                                                     momentum, optimizer,
                                                     epoch, args.gammas,
                                                     args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train_model(data_loader=data_train_loader,
                                           model=net,
                                           criterion=criterion,
                                           optimizer=optimizer,
                                           epoch=epoch,
                                           log=log,
                                           print_freq=200,
                                           use_cuda=args.use_cuda)

        # evaluate on test set
        print_log("Validation on test dataset:", log)
        val_acc, val_loss = validate(data_test_loader,
                                     net,
                                     criterion,
                                     log=log,
                                     use_cuda=args.use_cuda)
        recorder.update(epoch, train_los, train_acc, val_loss, val_acc)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'args': copy.deepcopy(args),
            }, model_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(model_path, 'curve.png'))

    log.close()
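
adjust_learning_rate is defined elsewhere in the repository. Judging from the gammas and schedule arguments passed above, it presumably implements a step-decay schedule; a minimal sketch under that assumption (not the repository's actual code):

def adjust_learning_rate(base_lr, momentum, optimizer, epoch, gammas, schedule):
    """Multiply base_lr by one gamma factor for every schedule milestone already passed."""
    lr = base_lr
    for gamma, milestone in zip(gammas, schedule):
        if epoch >= milestone:
            lr *= gamma
    # momentum is accepted for signature compatibility but left unchanged here
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr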
Example #3
def train_cae(trainloader, model, class_name, testloader, y_train, device,
              args):
    """
    model train function.
    :param trainloader:
    :param model:
    :param class_name:
    :param testloader:
    :param y_train: numpy array, sample normal/abnormal labels, [1 1 1 1 0 0] like, original sample size.
    :param device: cpu or gpu:0/1/...
    :param args:
    :return:
    """
    global_step = 0
    losses = AverageMeter()
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(1, args.epochs + 1):
        model.train()

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)
        print('{:3d}/{:3d} ----- {:s} {:s}'.format(epoch, args.epochs,
                                                   time_string(), need_time))

        mse = nn.MSELoss(reduction='mean')  # default

        lr = 0.1 / pow(2, np.floor(epoch / args.lr_schedule))
        logger.add_scalar(class_name + "/lr", lr, epoch)

        if args.optimizer == 'SGD':
            optimizer = optim.SGD(model.parameters(),
                                  lr=lr,
                                  weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(),
                                   eps=1e-7,
                                   weight_decay=0.0005)
        for batch_idx, (input, _, _) in enumerate(trainloader):
            optimizer.zero_grad()
            input = input.to(device)

            _, output = model(input)
            loss = mse(input, output)
            losses.update(loss.item(), 1)

            logger.add_scalar(class_name + '/loss', losses.avg, global_step)

            global_step = global_step + 1
            loss.backward()
            optimizer.step()

        # print losses
        print('Epoch: [{} | {}], loss: {:.4f}'.format(epoch, args.epochs,
                                                      losses.avg))

        # log images
        if epoch % args.log_img_steps == 0:
            os.makedirs(os.path.join(RESULTS_DIR, class_name), exist_ok=True)
            fpath = os.path.join(RESULTS_DIR, class_name,
                                 'pretrain_epoch_' + str(epoch) + '.png')
            visualize(input, output, fpath, num=32)

        # test while training
        if epoch % args.log_auc_steps == 0:
            rep, losses_result = test(testloader, model, class_name, args,
                                      device, epoch)

            centroid = torch.mean(rep, dim=0, keepdim=True)

            losses_result = losses_result - losses_result.min()
            losses_result = losses_result / (1e-8 + losses_result.max())
            scores = 1 - losses_result
            auroc_rec = roc_auc_score(y_train, scores)

            _, p = dec_loss_fun(rep, centroid)
            score_p = p[:, 0]
            auroc_dec = roc_auc_score(y_train, score_p)

            print("Epoch: [{} | {}], auroc_rec: {:.4f}; auroc_dec: {:.4f}".
                  format(epoch, args.epochs, auroc_rec, auroc_dec))

            logger.add_scalar(class_name + '/auroc_rec', auroc_rec, epoch)
            logger.add_scalar(class_name + '/auroc_dec', auroc_dec, epoch)

        # time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
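
The schedule lr = 0.1 / 2^floor(epoch / args.lr_schedule) halves the learning rate every args.lr_schedule epochs: with args.lr_schedule = 30, for example, epochs 1-29 train at 0.1, epochs 30-59 at 0.05, epochs 60-89 at 0.025, and so on. convert_secs2time is another repository helper; a minimal sketch, assuming it simply splits a duration in seconds into hours, minutes, and seconds:

def convert_secs2time(epoch_time):
    # Split a duration in seconds into (hours, minutes, seconds)
    need_hour = int(epoch_time / 3600)
    need_mins = int((epoch_time - 3600 * need_hour) / 60)
    need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins)
    return need_hour, need_mins, need_secs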
Example #4
def train_iae(trainloader, model, class_name, testloader, y_train, device,
              args):
    """
    model train function.
    :param trainloader:
    :param model:
    :param class_name:
    :param testloader:
    :param y_train: numpy array, sample normal/abnormal labels, [1 1 1 1 0 0] like, original sample size.
    :param device: cpu or gpu:0/1/...
    :param args:
    :return:
    """
    global_step = 0
    losses = AverageMeter()
    l2_losses = AverageMeter()
    svdd_losses = AverageMeter()

    start_time = time.time()
    epoch_time = AverageMeter()

    svdd_loss = torch.tensor(0, device=device)
    R = torch.tensor(0, device=device)
    c = torch.randn(256, device=device)

    for epoch in range(1, args.epochs + 1):
        model.train()

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)
        print('{:3d}/{:3d} ----- {:s} {:s}'.format(epoch, args.epochs,
                                                   time_string(), need_time))

        mse = nn.MSELoss(reduction='mean')  # default

        lr = 0.1 / pow(2, np.floor(epoch / args.lr_schedule))
        logger.add_scalar(class_name + "/lr", lr, epoch)

        if args.optimizer == 'sgd':
            optimizer = optim.SGD(model.parameters(),
                                  lr=lr,
                                  weight_decay=args.weight_decay)
        elif args.optimizer == 'adam':
            optimizer = optim.Adam(model.parameters(),
                                   eps=1e-7,
                                   weight_decay=args.weight_decay)
        else:
            raise NotImplementedError(
                'Optimizer {} not implemented.'.format(args.optimizer))

        for batch_idx, (input, _, _) in enumerate(trainloader):
            optimizer.zero_grad()
            input = input.to(device)

            reps, output = model(input)

            if epoch > args.pretrain_epochs:
                dist = torch.sum((reps - c)**2, dim=1)
                scores = dist - R**2
                svdd_loss = args.para_lambda * (
                    R**2 + (1 / args.para_nu) *
                    torch.mean(torch.max(torch.zeros_like(scores), scores)))

            l2_loss = mse(input, output)

            loss = l2_loss + svdd_loss

            l2_losses.update(l2_loss.item(), 1)
            svdd_losses.update(svdd_loss.item(), 1)
            losses.update(loss.item(), 1)

            logger.add_scalar(class_name + '/l2_loss', l2_losses.avg,
                              global_step)
            logger.add_scalar(class_name + '/svdd_loss', svdd_losses.avg,
                              global_step)
            logger.add_scalar(class_name + '/loss', losses.avg, global_step)

            logger.add_scalar(class_name + '/R', R.data, global_step)

            global_step = global_step + 1
            loss.backward()
            optimizer.step()

            # Update hypersphere radius R on mini-batch distances
            if epoch > args.pretrain_epochs:
                R.data = torch.tensor(get_radius(dist, args.para_nu),
                                      device=device)

        # print losses
        print('Epoch: [{} | {}], loss: {:.4f}'.format(epoch, args.epochs,
                                                      losses.avg))

        # log images
        if epoch % args.log_img_steps == 0:
            os.makedirs(os.path.join(RESULTS_DIR, class_name), exist_ok=True)
            fpath = os.path.join(RESULTS_DIR, class_name,
                                 'pretrain_epoch_' + str(epoch) + '.png')
            visualize(input, output, fpath, num=32)

        # test while training
        if epoch % args.log_auc_steps == 0:
            rep, losses_result = test(testloader, model, class_name, args,
                                      device, epoch)

            centroid = torch.mean(rep, dim=0, keepdim=True)

            losses_result = losses_result - losses_result.min()
            losses_result = losses_result / (1e-8 + losses_result.max())
            scores = 1 - losses_result
            auroc_rec = roc_auc_score(y_train, scores)

            _, p = dec_loss_fun(rep, centroid)
            score_p = p[:, 0]
            auroc_dec = roc_auc_score(y_train, score_p)

            print("Epoch: [{} | {}], auroc_rec: {:.4f}; auroc_dec: {:.4f}".
                  format(epoch, args.epochs, auroc_rec, auroc_dec))

            logger.add_scalar(class_name + '/auroc_rec', auroc_rec, epoch)
            logger.add_scalar(class_name + '/auroc_dec', auroc_dec, epoch)

        # initialize the centroid c once the pretraining phase finishes
        if epoch == args.pretrain_epochs:
            rep, losses_result = test(testloader, model, class_name, args,
                                      device, epoch)
            c = update_center_c(rep)

        # time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
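
get_radius and update_center_c are not shown. In the soft-boundary Deep SVDD formulation, the radius is set to the (1 - nu)-quantile of the distances to the center, and the center is kept away from the trivial all-zeros solution. Minimal sketches following the reference Deep SVDD implementation (an assumption about this repository's versions):

import numpy as np
import torch

def get_radius(dist, nu):
    # Radius R such that roughly a nu-fraction of points falls outside the sphere
    return np.quantile(np.sqrt(dist.clone().data.cpu().numpy()), 1 - nu)

def update_center_c(rep, eps=0.1):
    # Mean of the representations, pushed away from zero coordinate-wise
    c = torch.mean(rep, dim=0)
    c[(abs(c) < eps) & (c < 0)] = -eps
    c[(abs(c) < eps) & (c > 0)] = eps
    return c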
Example #5
def train(data_loader, model, criterion, optimizer, epsilon, num_iterations,
          targeted, target_class, log, print_freq=200, use_cuda=True):
    # train function (forward, backward, update)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.module.generator.train()
    model.module.target_model.eval()

    end = time.time()

    data_iterator = iter(data_loader)

    iteration = 0
    while iteration < num_iterations:
        try:
            input, target = next(data_iterator)
        except StopIteration:
            # StopIteration is thrown if dataset ends
            # reinitialize data loader
            data_iterator = iter(data_loader)
            input, target = next(data_iterator)

        if targeted:
            target = torch.ones(input.shape[0], dtype=torch.int64) * target_class
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            target = target.cuda()
            input = input.cuda()

        # compute output
        if model.module._get_name() == "Inception3":
            output, aux_output = model(input)
            loss1 = criterion(output, target)
            loss2 = criterion(aux_output, target)
            loss = loss1 + 0.4*loss2
        else:
            output = model(input)
            loss = criterion(output, target)

        # measure accuracy and record loss
        if len(target.shape) > 1:
            target = torch.argmax(target, dim=-1)
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Projection
        model.module.generator.uap.data = torch.clamp(model.module.generator.uap.data, -epsilon, epsilon)
        
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if iteration % print_freq == 0:
            print_log('  Iteration: [{:03d}/{:03d}]   '
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})   '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})   '
                        'Loss {loss.val:.4f} ({loss.avg:.4f})   '
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})   '
                        'Prec@5 {top5.val:.3f} ({top5.avg:.3f})   '.format(
                        iteration, num_iterations, batch_time=batch_time,
                        data_time=data_time, loss=losses, top1=top1, top5=top5) + time_string(), log)

        iteration += 1
    print_log('  **Train** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'.format(top1=top1,
                                                                                                    top5=top5,
                                                                                                    error1=100-top1.avg), log)
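
The clamp on model.module.generator.uap.data above is the projection step of universal adversarial perturbation (UAP) training: the only trainable tensor is a single perturbation that is added to every input before it reaches the frozen target model, and after each optimizer step it is projected back into the epsilon l-infinity ball. A minimal sketch of such a generator/model pair, assuming this structure (the repository's actual classes are not shown; the shape is hypothetical):

import torch
import torch.nn as nn

class UAPGenerator(nn.Module):
    """Holds one universal perturbation, added to every input."""
    def __init__(self, shape=(3, 224, 224)):
        super().__init__()
        self.uap = nn.Parameter(torch.zeros(shape))

    def forward(self, x):
        return x + self.uap  # broadcasts over the batch dimension

class UAPModel(nn.Module):
    """Perturbs the input, then classifies with the frozen target model."""
    def __init__(self, generator, target_model):
        super().__init__()
        self.generator = generator
        self.target_model = target_model

    def forward(self, x):
        return self.target_model(self.generator(x))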
Example #6
def train_half_half(sources_data_loader, others_data_loader,
                    model, target_model, criterion, optimizer, epsilon, num_iterations, log,
                    print_freq=200, use_cuda=True, patch=False):
    # train function (forward, backward, update)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.module.generator.train()
    model.module.target_model.eval()
    target_model.eval()

    end = time.time()

    sources_data_iterator = iter(sources_data_loader)
    others_data_iterator = iter(others_data_loader)

    iteration = 0
    while iteration < num_iterations:
        try:
            sources_input, sources_target = next(sources_data_iterator)
        except StopIteration:
            # StopIteration is thrown if dataset ends
            # reinitialize data loader
            sources_data_iterator = iter(sources_data_loader)
            sources_input, sources_target = next(sources_data_iterator)

        try:
            others_input, others_target = next(others_data_iterator)
        except StopIteration:
            # StopIteration is thrown if dataset ends
            # reinitialize data loader
            others_data_iterator = iter(others_data_loader)
            others_input, others_target = next(others_data_iterator)

        # Concat the two batches
        input = torch.cat([sources_input, others_input], dim=0)
        target = torch.cat([sources_target, others_target], dim=0)

        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            target = target.cuda()
            input = input.cuda()

        # compute output
        output = model(input)
        target_model_output = target_model(input)
        loss = criterion(output, target_model_output, target)

        # measure accuracy and record loss
        if len(target.shape) > 1:
            target = torch.argmax(target, dim=-1)
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Project to l-infinity ball
        if patch:
            model.module.generator.uap.data = torch.clamp(model.module.generator.uap.data, 0, epsilon)
        else:
            model.module.generator.uap.data = torch.clamp(model.module.generator.uap.data, -epsilon, epsilon)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if iteration % print_freq == 0:
            print_log('  Iteration: [{:03d}/{:03d}]   '
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})   '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})   '
                        'Loss {loss.val:.4f} ({loss.avg:.4f})   '
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})   '
                        'Prec@5 {top5.val:.3f} ({top5.avg:.3f})   '.format(
                        iteration, num_iterations, batch_time=batch_time,
                        data_time=data_time, loss=losses, top1=top1, top5=top5) + time_string(), log)

        iteration += 1

    print_log('  **Train** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Error@1 {error1:.3f}'.format(top1=top1,
                                                                                                    top5=top5,
                                                                                                    error1=100-top1.avg), log)
Example #7
def main():
    args = parse_arguments()

    random.seed(args.pretrained_seed)
    torch.manual_seed(args.pretrained_seed)
    if args.use_cuda:
        torch.cuda.manual_seed_all(args.pretrained_seed)
    cudnn.benchmark = True

    # get a path for saving the model to be trained
    model_path = get_model_path(dataset_name=args.pretrained_dataset,
                                network_arch=args.pretrained_arch,
                                random_seed=args.pretrained_seed)

    # Init logger
    log_file_name = os.path.join(model_path, 'log_seed_{}.txt'.format(args.pretrained_seed))
    print("Log file: {}".format(log_file_name))
    log = open(log_file_name, 'w')
    print_log('save path : {}'.format(model_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    for key, value in state.items():
        print_log("{} : {}".format(key, value), log)
    print_log("Random Seed: {}".format(args.pretrained_seed), log)
    print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("Torch  version : {}".format(torch.__version__), log)
    print_log("Cudnn  version : {}".format(torch.backends.cudnn.version()), log)
    # Get data specs
    num_classes, (mean, std), input_size, num_channels = get_data_specs(args.pretrained_dataset, args.pretrained_arch)
    pretrained_data_train, pretrained_data_test = get_data(args.pretrained_dataset,
                                                            mean=mean,
                                                            std=std,
                                                            input_size=input_size,
                                                            train_target_model=True)

    pretrained_data_train_loader = torch.utils.data.DataLoader(pretrained_data_train,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=args.workers,
                                                    pin_memory=True)

    pretrained_data_test_loader = torch.utils.data.DataLoader(pretrained_data_test,
                                                    batch_size=args.batch_size,
                                                    shuffle=False,
                                                    num_workers=args.workers,
                                                    pin_memory=True)


    print_log("=> Creating model '{}'".format(args.pretrained_arch), log)
    # Init model, criterion, and optimizer
    net = get_network(args.pretrained_arch, input_size=input_size, num_classes=num_classes, finetune=args.finetune)
    print_log("=> Network :\n {}".format(net), log)
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    non_trainable_params = get_num_non_trainable_parameters(net)
    trainable_params = get_num_trainable_parameters(net)
    total_params = get_num_parameters(net)
    print_log("Trainable parameters: {}".format(trainable_params), log)
    print_log("Non-trainable parameters: {}".format(non_trainable_params), log)
    print_log("Total # parameters: {}".format(total_params), log)

    # define loss function (criterion) and optimizer
    criterion_xent = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'], nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion_xent.cuda()

    recorder = RecorderMeter(args.epochs)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.epochs):
        current_learning_rate = adjust_learning_rate(args.learning_rate, args.momentum, optimizer, epoch, args.gammas, args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train_target_model(pretrained_data_train_loader, net, criterion_xent, optimizer, epoch, log,
                                    print_freq=args.print_freq,
                                    use_cuda=args.use_cuda)

        # evaluate on validation set
        print_log("Validation on pretrained test dataset:", log)
        val_acc = validate(pretrained_data_test_loader, net, criterion_xent, log, use_cuda=args.use_cuda)
        is_best = recorder.update(epoch, train_los, train_acc, 0., val_acc)

        save_checkpoint({
          'epoch'       : epoch + 1,
          'arch'        : args.pretrained_arch,
          'state_dict'  : net.state_dict(),
          'recorder'    : recorder,
          'optimizer'   : optimizer.state_dict(),
          'args'        : copy.deepcopy(args),
        }, model_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(model_path, 'curve.png'))

    log.close()
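
print_log, used throughout these examples, is a small repository helper; a minimal sketch, assuming it mirrors each message to stdout and to the open log file:

def print_log(print_string, log):
    print("{}".format(print_string))
    log.write('{}\n'.format(print_string))
    log.flush()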