Example #1
    def checkpoint(self, name=None, timestamp=False, tags=None, verbose=False):

        if name is None:
            name = self.config["name"]

        return save_checkpoint(self.get_state(),
                               name=name, tag=tags, timestamp=timestamp,
                               verbose=verbose)
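
Example #1 delegates the actual serialization to a save_checkpoint helper that is not shown. The sketch below is a minimal, assumed implementation based on torch.save; the file-naming logic and the handling of name, tag, and timestamp are guesses that mirror the call above, not the project's verbatim code.

import time

import torch

def save_checkpoint(state, name="model", tag=None, timestamp=False, verbose=False):
    # Hypothetical sketch: build a file name from the model name, optional
    # tag(s), and an optional timestamp, then serialize the state.
    parts = [name]
    if tag:
        parts.extend(tag if isinstance(tag, (list, tuple)) else [tag])
    if timestamp:
        parts.append(time.strftime("%Y%m%d-%H%M%S"))
    path = "{}.pt".format("_".join(parts))
    torch.save(state, path)
    if verbose:
        print("Saved checkpoint to {}".format(path))
    return path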
Example #2
def main():
    args = parse_arguments()

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.use_cuda:
        torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True

    model_path = get_model_path(args.dataset, args.arch, args.seed)

    # Init logger
    log_file_name = os.path.join(model_path, 'log.txt')
    print("Log file: {}".format(log_file_name))
    log = open(log_file_name, 'w')
    print_log('model path : {}'.format(model_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    for key, value in state.items():
        print_log("{} : {}".format(key, value), log)
    print_log("Random Seed: {}".format(args.seed), log)
    print_log("Python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("Torch  version : {}".format(torch.__version__), log)
    print_log("Cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # Data specifications for the websites dataset
    mean = [0., 0., 0.]  # with zero mean and unit std, Normalize is effectively a no-op
    std = [1., 1., 1.]
    input_size = 224
    num_classes = 4

    # Dataset
    traindir = os.path.join(WEBSITES_DATASET_PATH, 'train')
    valdir = os.path.join(WEBSITES_DATASET_PATH, 'val')

    train_transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    test_transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    data_train = dset.ImageFolder(root=traindir, transform=train_transform)
    data_test = dset.ImageFolder(root=valdir, transform=test_transform)

    # Dataloader
    data_train_loader = torch.utils.data.DataLoader(data_train,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=args.workers,
                                                    pin_memory=True)
    data_test_loader = torch.utils.data.DataLoader(data_test,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True)

    # Network
    if args.arch == "vgg16":
        net = models.vgg16(pretrained=True)
    elif args.arch == "vgg19":
        net = models.vgg19(pretrained=True)
    elif args.arch == "resnet18":
        net = models.resnet18(pretrained=True)
    elif args.arch == "resnet50":
        net = models.resnet50(pretrained=True)
    elif args.arch == "resnet101":
        net = models.resnet101(pretrained=True)
    elif args.arch == "resnet152":
        net = models.resnet152(pretrained=True)
    else:
        raise ValueError("Network {} not supported".format(args.arch))

    if num_classes != 1000:
        # Presumably replaces the pretrained ImageNet head with a
        # num_classes-way output layer; the helper itself is not shown here.
        net = manipulate_net_architecture(model_arch=args.arch,
                                          net=net,
                                          num_classes=num_classes)

    # Loss function
    if args.loss_function == "ce":
        criterion = torch.nn.CrossEntropyLoss()
    else:
        raise ValueError("Loss function {} not supported".format(args.loss_function))

    # Cuda
    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    # Optimizer
    momentum = 0.9
    decay = 5e-4
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.learning_rate,
                                momentum=momentum,
                                weight_decay=decay,
                                nesterov=True)

    recorder = RecorderMeter(args.epochs)
    start_time = time.time()
    epoch_time = AverageMeter()

    # Main loop
    for epoch in range(args.epochs):
        current_learning_rate = adjust_learning_rate(args.learning_rate,
                                                     momentum, optimizer,
                                                     epoch, args.gammas,
                                                     args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(
                      time_string(), epoch, args.epochs, need_time, current_learning_rate)
                  + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(
                      recorder.max_accuracy(False), 100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train_model(data_loader=data_train_loader,
                                           model=net,
                                           criterion=criterion,
                                           optimizer=optimizer,
                                           epoch=epoch,
                                           log=log,
                                           print_freq=200,
                                           use_cuda=args.use_cuda)

        # evaluate on test set
        print_log("Validation on test dataset:", log)
        val_acc, val_loss = validate(data_test_loader,
                                     net,
                                     criterion,
                                     log=log,
                                     use_cuda=args.use_cuda)
        recorder.update(epoch, train_los, train_acc, val_loss, val_acc)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'args': copy.deepcopy(args),
            }, model_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(model_path, 'curve.png'))

    log.close()
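
Examples #2 and #8 both rely on an adjust_learning_rate(base_lr, momentum, optimizer, epoch, gammas, schedule) helper that is not shown. A plausible sketch of the step-decay behavior its (gammas, schedule) arguments suggest, offered as an assumption rather than the project's exact code:

def adjust_learning_rate(base_lr, momentum, optimizer, epoch, gammas, schedule):
    # Multiply the base LR by each gamma whose milestone epoch has been passed.
    # momentum is accepted only to match the call site; it is unused here.
    assert len(gammas) == len(schedule), "gammas and schedule must have equal length"
    lr = base_lr
    for gamma, milestone in zip(gammas, schedule):
        if epoch >= milestone:
            lr = lr * gamma
        else:
            break
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr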
Example #3
def main():
    args = parse_arguments()

    random.seed(args.pretrained_seed)
    torch.manual_seed(args.pretrained_seed)
    if args.use_cuda:
        torch.cuda.manual_seed_all(args.pretrained_seed)
    cudnn.benchmark = True

    # get the result path to store the results
    result_path = get_result_path(dataset_name=args.dataset,
                                network_arch=args.pretrained_arch,
                                random_seed=args.pretrained_seed,
                                result_subfolder=args.result_subfolder,
                                postfix=args.postfix)

    # Init logger
    log_file_name = os.path.join(result_path, 'log.txt')
    print("Log file: {}".format(log_file_name))
    log = open(log_file_name, 'w')
    print_log('save path : {}'.format(result_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    for key, value in state.items():
        print_log("{} : {}".format(key, value), log)
    print_log("Random Seed: {}".format(args.pretrained_seed), log)
    print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("Torch  version : {}".format(torch.__version__), log)
    print_log("Cudnn  version : {}".format(torch.backends.cudnn.version()), log)

    _, pretrained_data_test = get_data(args.pretrained_dataset, args.pretrained_dataset)

    pretrained_data_test_loader = torch.utils.data.DataLoader(pretrained_data_test,
                                                    batch_size=args.batch_size,
                                                    shuffle=False,
                                                    num_workers=args.workers,
                                                    pin_memory=True)

    #### Dataloader for training ####
    num_classes, (mean, std), input_size, num_channels = get_data_specs(args.pretrained_dataset)

    data_train, _ = get_data(args.dataset, args.pretrained_dataset)
    data_train_loader = torch.utils.data.DataLoader(data_train,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=args.workers,
                                                    pin_memory=True)

    ####################################
    # Init model, criterion, and optimizer
    print_log("=> Creating model '{}'".format(args.pretrained_arch), log)
    # get a path for loading the model to be attacked
    model_path = get_model_path(dataset_name=args.pretrained_dataset,
                                network_arch=args.pretrained_arch,
                                random_seed=args.pretrained_seed)
    model_weights_path = os.path.join(model_path, "checkpoint.pth.tar")

    target_network = get_network(args.pretrained_arch,
                                input_size=input_size,
                                num_classes=num_classes,
                                finetune=False)

    print_log("=> Network :\n {}".format(target_network), log)
    target_network = torch.nn.DataParallel(target_network, device_ids=list(range(args.ngpu)))
    # Set the target model into evaluation mode
    target_network.eval()
    # Imagenet models use the pretrained pytorch weights
    if args.pretrained_dataset != "imagenet":
        network_data = torch.load(model_weights_path)
        target_network.load_state_dict(network_data['state_dict'])

    # Set all weights to not trainable
    set_parameter_requires_grad(target_network, requires_grad=False)

    non_trainable_params = get_num_non_trainable_parameters(target_network)
    trainable_params = get_num_trainable_parameters(target_network)
    total_params = get_num_parameters(target_network)
    print_log("Target Network Trainable parameters: {}".format(trainable_params), log)
    print_log("Target Network Non Trainable parameters: {}".format(non_trainable_params), log)
    print_log("Target Network Total # parameters: {}".format(total_params), log)

    print_log("=> Inserting Generator", log)

    generator = UAP(shape=(input_size, input_size),
                num_channels=num_channels,
                mean=mean,
                std=std,
                use_cuda=args.use_cuda)

    print_log("=> Generator :\n {}".format(generator), log)
    non_trainable_params = get_num_non_trainable_parameters(generator)
    trainable_params = get_num_trainable_parameters(generator)
    total_params = get_num_parameters(generator)
    print_log("Generator Trainable parameters: {}".format(trainable_params), log)
    print_log("Generator Non Trainable parameters: {}".format(non_trainable_params), log)
    print_log("Generator Total # parameters: {}".format(total_params), log)

    # Chain the perturbation generator in front of the frozen target model, so
    # optimizing perturbed_net can only change the generator's parameters.
    perturbed_net = nn.Sequential(OrderedDict([('generator', generator), ('target_model', target_network)]))
    perturbed_net = torch.nn.DataParallel(perturbed_net, device_ids=list(range(args.ngpu)))

    non_trainable_params = get_num_non_trainable_parameters(perturbed_net)
    trainable_params = get_num_trainable_parameters(perturbed_net)
    total_params = get_num_parameters(perturbed_net)
    print_log("Perturbed Net Trainable parameters: {}".format(trainable_params), log)
    print_log("Perturbed Net Non Trainable parameters: {}".format(non_trainable_params), log)
    print_log("Perturbed Net Total # parameters: {}".format(total_params), log)

    # Set the target model into evaluation mode
    perturbed_net.module.target_model.eval()
    perturbed_net.module.generator.train()

    if args.loss_function == "ce":
        criterion = torch.nn.CrossEntropyLoss()
    elif args.loss_function == "neg_ce":
        criterion = NegativeCrossEntropy()
    elif args.loss_function == "logit":
        criterion = LogitLoss(num_classes=num_classes, use_cuda=args.use_cuda)
    elif args.loss_function == "bounded_logit":
        criterion = BoundedLogitLoss(num_classes=num_classes, confidence=args.confidence, use_cuda=args.use_cuda)
    elif args.loss_function == "bounded_logit_fixed_ref":
        criterion = BoundedLogitLossFixedRef(num_classes=num_classes, confidence=args.confidence, use_cuda=args.use_cuda)
    elif args.loss_function == "bounded_logit_neg":
        criterion = BoundedLogitLoss_neg(num_classes=num_classes, confidence=args.confidence, use_cuda=args.use_cuda)
    else:
        raise ValueError("Loss function {} not supported".format(args.loss_function))

    if args.use_cuda:
        target_network.cuda()
        generator.cuda()
        perturbed_net.cuda()
        criterion.cuda()

    optimizer = torch.optim.Adam(perturbed_net.parameters(), lr=state['learning_rate'])
    
    # Measure the time needed for the UAP generation
    start = time.time()
    train(data_loader=data_train_loader,
            model=perturbed_net,
            criterion=criterion,
            optimizer=optimizer,
            epsilon=args.epsilon,
            num_iterations=args.num_iterations,
            targeted=args.targeted,
            target_class=args.target_class,
            log=log,
            print_freq=args.print_freq,
            use_cuda=args.use_cuda)
    end = time.time()
    print_log("Time needed for UAP generation: {}".format(end - start), log)
    # evaluate
    print_log("Final evaluation:", log)
    metrics_evaluate(data_loader=pretrained_data_test_loader,
                    target_model=target_network,
                    perturbed_model=perturbed_net,
                    targeted=args.targeted,
                    target_class=args.target_class,
                    log=log,
                    use_cuda=args.use_cuda)

    save_checkpoint({
      'arch'        : args.pretrained_arch,
      # 'state_dict'  : perturbed_net.state_dict(),
      'state_dict'  : perturbed_net.module.generator.state_dict(),
      'optimizer'   : optimizer.state_dict(),
      'args'        : copy.deepcopy(args),
    }, result_path, 'checkpoint.pth.tar')

    log.close()
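
The UAP module used in Examples #3 and #4 is not defined in these snippets. Conceptually it holds a single learnable perturbation that is added to every input before the frozen target network sees it. A minimal sketch under that assumption (the real class likely also clamps the perturbation and accounts for input normalization via mean and std):

import torch
import torch.nn as nn

class UAP(nn.Module):
    def __init__(self, shape=(224, 224), num_channels=3, mean=None, std=None, use_cuda=False):
        super(UAP, self).__init__()
        # mean/std/use_cuda are accepted for API compatibility with the call sites
        self.mean, self.std = mean, std
        # One perturbation tensor shared by all inputs, hence "universal"
        self.uap = nn.Parameter(torch.zeros(1, num_channels, *shape))

    def forward(self, x):
        # Broadcast the perturbation over the batch dimension
        return x + self.uap

Because the target network is frozen via set_parameter_requires_grad, optimizing perturbed_net can only update this perturbation tensor.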
Example #4
File: train.py Project: stry/DBSN
        test_loss, test_err, test_iou = train_utils.test(
            model, loaders['test'], criterion, alphas, betas, biOptimizer,
            ngpus, dbsn)
        logging.info('Test - Loss: {:.4f} | Acc: {:.4f} | IOU: {:.4f}'.format(
            test_loss, 1 - test_err, test_iou))

    # time_elapsed = time.time() - since
    # logging.info('Total Time {:.0f}m {:.0f}s'.format(
    #     time_elapsed // 60, time_elapsed % 60))

    ### Checkpoint ###
    if epoch % args.save_freq == 0 or epoch == args.epochs:
        logging.info('Saving model at Epoch: {}'.format(epoch))
        train_utils.save_checkpoint(dir=args.dir,
                                    epoch=epoch,
                                    state_dict=model.state_dict(),
                                    optimizer=optimizer.state_dict(),
                                    alphas=alphas,
                                    betas=betas)

    if args.optimizer == 'RMSProp':
        ### Adjust Lr ###
        if epoch < args.ft_start:
            scheduler.step(epoch=epoch)
        else:
            #scheduler.step(epoch=-1) #reset to args.lr_init for fine-tuning
            train_utils.adjust_learning_rate(optimizer, args.ft_lr)

    elif args.optimizer == 'SGD':
        lr = train_utils.schedule(epoch, args.lr_init, args.epochs)
        train_utils.adjust_learning_rate(optimizer, lr)
            "Val - Loss: {:.4f} | Acc: {:.4f} | IOU: {:.4f}".format(
                val_loss, 1 - val_err, val_iou
            )
        )
        writer.add_scalar("val/loss", val_loss, epoch)
        writer.add_scalar("val/error", val_err, epoch)

    time_elapsed = time.time() - since
    print("Total Time {:.0f}m {:.0f}s\n".format(time_elapsed // 60, time_elapsed % 60))

    ### Checkpoint ###
    if epoch % args.save_freq == 0:
        print("Saving model at Epoch: ", epoch)
        save_checkpoint(
            dir=args.dir,
            epoch=epoch,
            state_dict=model.state_dict(),
            optimizer=optimizer.state_dict(),
        )

    lr = schedule(
        epoch, args.lr_init, args.epochs
    )
    adjust_learning_rate(optimizer, lr)
    writer.add_scalar("hypers/lr", lr, epoch)

### Test set ###

test_loss, test_err, test_iou = train_utils.test(model, loaders["test"], criterion)
print(
    "SGD Test - Loss: {:.4f} | Acc: {:.4f} | IOU: {:.4f}".format(
        test_loss, 1 - test_err, test_iou
    )
)

Example #5
def main():
    args = parse_arguments()

    random.seed(args.pretrained_seed)
    torch.manual_seed(args.pretrained_seed)
    if args.use_cuda:
        torch.cuda.manual_seed_all(args.pretrained_seed)
    cudnn.benchmark = True

    # Get data specs
    num_classes, (mean, std), input_size, num_channels = get_data_specs(args.pretrained_dataset, args.pretrained_arch)

    # Construct the array of "other" classes:
    if args.pretrained_dataset in ["imagenet", "ycb"]:
        other_classes = args.source_classes
    else:
        all_classes = np.arange(num_classes)
        other_classes = [int(cl) for cl in all_classes if cl not in args.source_classes]

    half_batch_size = args.batch_size//2

    # get the result path to store the results
    result_path = get_result_path(dataset_name=args.pretrained_dataset,
                                network_arch=args.pretrained_arch,
                                random_seed=args.pretrained_seed,
                                result_subfolder=args.result_subfolder,
                                source_class=args.source_classes,
                                sink_class=args.sink_classes,
                                postfix=args.postfix)

    # Init logger
    log_file_name = os.path.join(result_path, 'log.txt')
    print("Log file: {}".format(log_file_name))
    log = open(log_file_name, 'w')
    print_log('save path : {}'.format(result_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    for key, value in state.items():
        print_log("{} : {}".format(key, value), log)
    print_log("Random Seed: {}".format(args.pretrained_seed), log)
    print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("Torch  version : {}".format(torch.__version__), log)
    print_log("Cudnn  version : {}".format(torch.backends.cudnn.version()), log)

    data_train_sources, data_test_sources = get_data(args.pretrained_dataset,
                                                    mean=mean,
                                                    std=std,
                                                    input_size=input_size,
                                                    classes=args.source_classes,
                                                    train_samples_per_class=args.num_train_samples_per_class)
    data_train_sources_loader = torch.utils.data.DataLoader(data_train_sources,
                                                    batch_size=half_batch_size,
                                                    shuffle=True,
                                                    num_workers=args.workers,
                                                    pin_memory=True)

    data_test_sources_loader = torch.utils.data.DataLoader(data_test_sources,
                                                    batch_size=half_batch_size,
                                                    shuffle=False,
                                                    num_workers=args.workers,
                                                    pin_memory=True)

    data_train_others, data_test_others = get_data(args.pretrained_dataset,
                                                    mean=mean,
                                                    std=std,
                                                    input_size=input_size,
                                                    classes=other_classes,
                                                    others=True,
                                                    train_samples_per_class=args.num_train_samples_per_class)
    data_train_others_loader = torch.utils.data.DataLoader(data_train_others,
                                                    batch_size=half_batch_size,
                                                    shuffle=True,
                                                    num_workers=args.workers,
                                                    pin_memory=True)

    data_test_others_loader = torch.utils.data.DataLoader(data_test_others,
                                                    batch_size=half_batch_size,
                                                    shuffle=False,
                                                    num_workers=args.workers,
                                                    pin_memory=True)

    # Init model, criterion, and optimizer
    print_log("=> Creating model '{}'".format(args.pretrained_arch), log)
    # get a path for loading the model to be attacked
    model_path = get_model_path(dataset_name=args.pretrained_dataset,
                                network_arch=args.pretrained_arch,
                                random_seed=args.pretrained_seed)
    model_weights_path = os.path.join(model_path, "checkpoint.pth.tar")

    target_network = get_network(args.pretrained_arch, input_size=input_size, num_classes=num_classes, finetune=args.finetune)
    # print_log("=> Network :\n {}".format(target_network), log)
    target_network = torch.nn.DataParallel(target_network, device_ids=list(range(args.ngpu)))
    # Set the target model into evaluation mode
    target_network.eval()
    # Imagenet models use the pretrained pytorch weights
    if args.pretrained_dataset != "imagenet":
        network_data = torch.load(model_weights_path)
        target_network.load_state_dict(network_data['state_dict'])

    # Set all weights to not trainable
    set_parameter_requires_grad(target_network, requires_grad=False)

    non_trainable_params = get_num_non_trainable_parameters(target_network)
    trainable_params = get_num_trainable_parameters(target_network)
    total_params = get_num_parameters(target_network)
    print_log("Target Network Trainable parameters: {}".format(trainable_params), log)
    print_log("Target Network Non Trainable parameters: {}".format(non_trainable_params), log)
    print_log("Target Network Total # parameters: {}".format(total_params), log)

    print_log("=> Inserting Generator", log)
    generator = UAP(shape=(input_size, input_size),
                    num_channels=num_channels,
                    mean=mean,
                    std=std,
                    use_cuda=args.use_cuda)

    print_log("=> Generator :\n {}".format(generator), log)
    non_trainable_params = get_num_non_trainable_parameters(generator)
    trainable_params = get_num_trainable_parameters(generator)
    total_params = get_num_parameters(generator)
    print_log("Generator Trainable parameters: {}".format(trainable_params), log)
    print_log("Generator Non Trainable parameters: {}".format(non_trainable_params), log)
    print_log("Generator Total # parameters: {}".format(total_params), log)

    perturbed_net = nn.Sequential(OrderedDict([('generator', generator), ('target_model', target_network)]))
    perturbed_net = torch.nn.DataParallel(perturbed_net, device_ids=list(range(args.ngpu)))

    non_trainable_params = get_num_non_trainable_parameters(perturbed_net)
    trainable_params = get_num_trainable_parameters(perturbed_net)
    total_params = get_num_parameters(perturbed_net)
    print_log("Perturbed Net Trainable parameters: {}".format(trainable_params), log)
    print_log("Perturbed Net Non Trainable parameters: {}".format(non_trainable_params), log)
    print_log("Perturbed Net Total # parameters: {}".format(total_params), log)

    # Set the target model into evaluation mode
    perturbed_net.module.target_model.eval()
    perturbed_net.module.generator.train()

    criterion = LossConstructor(source_classes=args.source_classes,
                                sink_classes=args.sink_classes,
                                num_classes=num_classes,
                                source_loss=args.source_loss,
                                others_loss=args.others_loss,
                                confidence=args.confidence,
                                alpha=args.alpha,
                                use_cuda=args.use_cuda)

    if args.use_cuda:
        target_network.cuda()
        generator.cuda()
        perturbed_net.cuda()
        criterion.cuda()

    optimizer = torch.optim.Adam(perturbed_net.parameters(),
                                    lr=state['learning_rate']) # betas=(0.5, 0.999)

    if args.pretrained_dataset not in ["imagenet", "ycb"]:
        metrics_evaluate(source_loader=data_test_sources_loader,
                                others_loader=data_test_others_loader,
                                target_model=target_network,
                                perturbed_model=perturbed_net,
                                source_classes=args.source_classes,
                                sink_classes=args.sink_classes,
                                log=log,
                                use_cuda=args.use_cuda)

    start_time = time.time()
    train_half_half(sources_data_loader=data_train_sources_loader,
                    others_data_loader=data_train_others_loader,
                    model=perturbed_net,
                    target_model=target_network,
                    criterion=criterion,
                    optimizer=optimizer,
                    epsilon=args.epsilon,
                    num_iterations=args.num_iterations,
                    log=log,
                    print_freq=args.print_freq,
                    use_cuda=args.use_cuda)
    end_time = time.time()
    print_log("Elapsed generation time: {}".format(end_time-start_time), log)

    # evaluate
    print_log("Final evaluation:", log)
    metrics_evaluate(source_loader=data_test_sources_loader,
                    others_loader=data_test_others_loader,
                    target_model=target_network,
                    perturbed_model=perturbed_net,
                    source_classes=args.source_classes,
                    sink_classes=args.sink_classes,
                    log=log,
                    use_cuda=args.use_cuda)

    save_checkpoint({
      'arch'        : args.pretrained_arch,
      'state_dict'  : perturbed_net.state_dict(),
      'optimizer'   : optimizer.state_dict(),
      'args'        : copy.deepcopy(args),
    }, result_path, 'checkpoint.pth.tar')

    # Plot the adversarial perturbation
    uap_numpy = perturbed_net.module.generator.uap.detach().cpu().numpy()
    # Calculate the norm
    uap_norm = np.linalg.norm(uap_numpy.reshape(-1), np.inf)
    print_log("Norm of UAP: {}".format(uap_norm), log)

    log.close()
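
The freezing and parameter-counting helpers used throughout Examples #3, #4, and #8 are simple to reconstruct; plausible definitions (assumptions, not the project's verbatim code):

def set_parameter_requires_grad(model, requires_grad=False):
    for param in model.parameters():
        param.requires_grad = requires_grad

def get_num_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def get_num_non_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if not p.requires_grad)

def get_num_parameters(model):
    return sum(p.numel() for p in model.parameters())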
Example #7
    experiment.metrics["loss_lm_" + name].append(tag="train", value=avg_loss)
    experiment.metrics["ppl_lm_" + name].append(tag="train",
                                                value=math.exp(avg_loss))

    experiment.metrics["loss_lm_" + name].append(tag="val", value=avg_val_loss)
    experiment.metrics["ppl_lm_" + name].append(tag="val",
                                                value=math.exp(avg_val_loss))

    ############################################################
    # epoch summary
    ############################################################
    epoch_summary("train", avg_loss)
    epoch_summary("val", avg_val_loss)

    # after updating all the values, refresh the plots
    experiment.update_plots()

    # Save the model if the validation loss is the best we've seen so far.
    if not best_loss or avg_val_loss < best_loss:
        print("saving checkpoint...")
        save_checkpoint("{}".format(name),
                        model,
                        optimizer,
                        train_set.vocab,
                        loss=avg_val_loss,
                        timestamp=True)
        best_loss = avg_val_loss

    print()
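
One caveat in Example #7: the guard "if not best_loss or avg_val_loss < best_loss" also fires when best_loss is exactly 0.0, because 0.0 is falsy. If a zero validation loss is possible, an explicit None test is the safer variant:

# best_loss is assumed to start as None before the first epoch
if best_loss is None or avg_val_loss < best_loss:
    best_loss = avg_val_loss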
Example #8
def main():
    args = parse_arguments()

    random.seed(args.pretrained_seed)
    torch.manual_seed(args.pretrained_seed)
    if args.use_cuda:
        torch.cuda.manual_seed_all(args.pretrained_seed)
    cudnn.benchmark = True

    # get a path for saving the model to be trained
    model_path = get_model_path(dataset_name=args.pretrained_dataset,
                                network_arch=args.pretrained_arch,
                                random_seed=args.pretrained_seed)

    # Init logger
    log_file_name = os.path.join(model_path, 'log_seed_{}.txt'.format(args.pretrained_seed))
    print("Log file: {}".format(log_file_name))
    log = open(log_file_name, 'w')
    print_log('save path : {}'.format(model_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    for key, value in state.items():
        print_log("{} : {}".format(key, value), log)
    print_log("Random Seed: {}".format(args.pretrained_seed), log)
    print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("Torch  version : {}".format(torch.__version__), log)
    print_log("Cudnn  version : {}".format(torch.backends.cudnn.version()), log)
    # Get data specs
    num_classes, (mean, std), input_size, num_channels = get_data_specs(args.pretrained_dataset, args.pretrained_arch)
    pretrained_data_train, pretrained_data_test = get_data(args.pretrained_dataset,
                                                            mean=mean,
                                                            std=std,
                                                            input_size=input_size,
                                                            train_target_model=True)

    pretrained_data_train_loader = torch.utils.data.DataLoader(pretrained_data_train,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=args.workers,
                                                    pin_memory=True)

    pretrained_data_test_loader = torch.utils.data.DataLoader(pretrained_data_test,
                                                    batch_size=args.batch_size,
                                                    shuffle=False,
                                                    num_workers=args.workers,
                                                    pin_memory=True)


    print_log("=> Creating model '{}'".format(args.pretrained_arch), log)
    # Init model, criterion, and optimizer
    net = get_network(args.pretrained_arch, input_size=input_size, num_classes=num_classes, finetune=args.finetune)
    print_log("=> Network :\n {}".format(net), log)
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    non_trainable_params = get_num_non_trainable_parameters(net)
    trainable_params = get_num_trainable_parameters(net)
    total_params = get_num_parameters(net)
    print_log("Trainable parameters: {}".format(trainable_params), log)
    print_log("Non Trainable parameters: {}".format(non_trainable_params), log)
    print_log("Total # parameters: {}".format(total_params), log)

    # define loss function (criterion) and optimizer
    criterion_xent = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion_xent.cuda()

    recorder = RecorderMeter(args.epochs)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.epochs):
        current_learning_rate = adjust_learning_rate(args.learning_rate, args.momentum, optimizer, epoch, args.gammas, args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(
                      time_string(), epoch, args.epochs, need_time, current_learning_rate)
                  + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(
                      recorder.max_accuracy(False), 100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train_target_model(pretrained_data_train_loader, net, criterion_xent, optimizer, epoch, log,
                                    print_freq=args.print_freq,
                                    use_cuda=args.use_cuda)

        # evaluate on validation set
        print_log("Validation on pretrained test dataset:", log)
        val_acc = validate(pretrained_data_test_loader, net, criterion_xent, log, use_cuda=args.use_cuda)
        is_best = recorder.update(epoch, train_los, train_acc, 0., val_acc)

        save_checkpoint({
          'epoch'       : epoch + 1,
          'arch'        : args.pretrained_arch,
          'state_dict'  : net.state_dict(),
          'recorder'    : recorder,
          'optimizer'   : optimizer.state_dict(),
          'args'        : copy.deepcopy(args),
        }, model_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(model_path, 'curve.png'))

    log.close()

Example #9
    def train(self):
        # Loop over the training epochs; start_epoch comes from a checkpoint and is 0 when there is none
        for epoch in range(self.start_epoch, self.args.epochs):

            # optimizer.param_groups[0] is a dict with six keys: ['amsgrad', 'params', 'lr', 'betas', 'weight_decay', 'eps']
            # here we only read the learning rate so it can be printed
            lr = self.g_optimizer.param_groups[0]['lr']
            print('learning rate = %.7f' % lr)

            for i, (a_real,
                    b_real) in enumerate(zip(self.a_loader, self.b_loader)):
                # Generator Computations
                # Freeze D while training G (presumably D's loss does not involve G, so G itself need not be locked)
                training.set_grad([self.Da, self.Db], False)
                self.g_optimizer.zero_grad()

                a_real = Variable(a_real[0])
                b_real = Variable(b_real[0])
                a_real, b_real = network.cuda([a_real, b_real])

                # Forward pass through generators
                a_fake = self.Gab(b_real)  # generate a fake a from the real b images
                b_fake = self.Gba(a_real)  # generate a fake b from the real a images

                a_recon = self.Gab(b_fake)  # reconstruct a from the fake b images
                b_recon = self.Gba(a_fake)  # reconstruct b from the fake a images

                a_idt = self.Gab(a_real)  # note: a_idt is obtained by feeding a_real into Gab, which normally maps b to a
                b_idt = self.Gba(b_real)

                # Identity losses
                # This loss does not seem to be mentioned in the paper, but it is commonly used;
                # it encourages the translation to leave the image content essentially unchanged
                a_idt_loss = self.L1(
                    a_idt, a_real) * self.args.lamda * self.args.idt_coef
                b_idt_loss = self.L1(
                    b_idt, b_real) * self.args.lamda * self.args.idt_coef

                # Adversarial losses
                # From here one can infer that D outputs the probability that an image is real (from G's point of view)
                a_fake_dis = self.Da(a_fake)
                b_fake_dis = self.Db(b_fake)

                # The probability of these fake images being judged real should be 1
                real_label = network.cuda(
                    Variable(torch.ones(a_fake_dis.size())))

                # Following the paper, compute MSE as G's adversarial loss
                a_gen_loss = self.MSE(a_fake_dis, real_label)
                b_gen_loss = self.MSE(b_fake_dis, real_label)

                # Cycle-consistency losses
                a_cycle_loss = self.L1(a_recon, a_real) * self.args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * self.args.lamda

                # Total generator loss
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # Backpropagate the loss; the optimizer updates G's parameters
                gen_loss.backward()
                self.g_optimizer.step()

                # Start training D
                training.set_grad([self.Da, self.Db], True)  # unfreeze D so it can be trained
                self.d_optimizer.zero_grad()

                # The buffer object is callable: newly generated fake images are pushed
                # through the buffer, and some samples are drawn from it to update D
                a_fake = Variable(
                    torch.Tensor(
                        self.a_fake_sample([a_fake.cpu().data.numpy()])[0]))
                b_fake = Variable(
                    torch.Tensor(
                        self.b_fake_sample([b_fake.cpu().data.numpy()])[0]))
                a_fake, b_fake = network.cuda([a_fake, b_fake])

                # Let the two discriminators judge the real and the generated samples
                a_real_dis = self.Da(a_real)
                a_fake_dis = self.Da(a_fake)
                b_real_dis = self.Db(b_real)
                b_fake_dis = self.Db(b_fake)
                real_label = network.cuda(
                    Variable(torch.ones(a_real_dis.size())))
                fake_label = network.cuda(
                    Variable(torch.zeros(a_fake_dis.size())))

                # All of these losses use MSE against the labels
                a_dis_real_loss = self.MSE(a_real_dis, real_label)
                a_dis_fake_loss = self.MSE(a_fake_dis, fake_label)
                b_dis_real_loss = self.MSE(b_real_dis, real_label)
                b_dis_fake_loss = self.MSE(b_fake_dis, fake_label)

                # Total discriminators losses
                a_dis_loss = (a_dis_real_loss + a_dis_fake_loss) * 0.5
                b_dis_loss = (b_dis_real_loss + b_dis_fake_loss) * 0.5

                # Backpropagate the losses; the optimizer updates D's parameters
                a_dis_loss.backward()
                b_dis_loss.backward()
                self.d_optimizer.step()

                print("Epoch: (%3d) (%5d/%5d) | Gen Loss:%.2e | Dis Loss:%.2e" %
                      (epoch, i + 1, min(len(self.a_loader), len(self.b_loader)),
                       gen_loss, a_dis_loss + b_dis_loss))

            # Save the latest checkpoint
            training.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'Da': self.Da.state_dict(),
                    'Db': self.Db.state_dict(),
                    'Gab': self.Gab.state_dict(),
                    'Gba': self.Gba.state_dict(),
                    'd_optimizer': self.d_optimizer.state_dict(),
                    'g_optimizer': self.g_optimizer.state_dict()
                }, '%s/latest.ckpt' % self.args.checkpoint_dir)

            # Update the learning rates
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()
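
The a_fake_sample and b_fake_sample objects above are history buffers in the spirit of the CycleGAN "image pool": with probability 0.5 a newly generated fake image is swapped with one stored earlier, so the discriminators also see older generator outputs. A sketch under that assumption (the pool size and swap policy are hypothetical details):

import copy
import random

class ItemPool(object):
    def __init__(self, max_num=50):
        self.max_num = max_num
        self.items = []

    def __call__(self, in_items):
        # in_items is a list (e.g. of numpy arrays); a list of equal length is returned
        if self.max_num <= 0:
            return in_items
        return_items = []
        for in_item in in_items:
            if len(self.items) < self.max_num:
                # Pool not yet full: store the new item and return it unchanged
                self.items.append(in_item)
                return_items.append(in_item)
            elif random.random() > 0.5:
                # Swap the new item with a randomly chosen stored one
                idx = random.randint(0, self.max_num - 1)
                tmp = copy.copy(self.items[idx])
                self.items[idx] = in_item
                return_items.append(tmp)
            else:
                return_items.append(in_item)
        return return_items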