Esempio n. 1
0
def main_train_worker(args):
    """Train classifier ``args.arch`` on ``args.dataset``, checkpointing after every epoch.

    The network comes from ``MetaLearnerModelBuilder.construct_cifar_model``.
    It is moved to CUDA and optimised with RAdam on a cross-entropy loss.
    Each epoch runs ``train`` then ``validate``, then writes a checkpoint
    (epoch, arch, model and optimizer state) to a fixed path derived from
    PY_ROOT and the training arguments. The previous checkpoint is
    overwritten each time. ``validate``'s return value is not used, so no
    "best model" selection happens here.

    Args:
        args: namespace-like object. Fields read here are ``gpu``, ``arch``,
            ``dataset``, ``epochs``, ``lr``, ``batch_size`` and
            ``weight_decay``. The whole object is also forwarded to
            ``train``/``validate``.
    """
    # NOTE(review): args.gpu is only reported; .cuda() below uses the current
    # default device, so device selection presumably happens elsewhere.
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
    print("=> creating model '{}'".format(args.arch))
    model = MetaLearnerModelBuilder.construct_cifar_model(args.arch, args.dataset)

    path_template = ("{root}/train_pytorch_model/real_image_model/"
                     "{dataset}@{arch}@epoch_{epochs}@lr_{lr}@batch_{batch}.pth.tar")
    checkpoint_path = path_template.format(root=PY_ROOT,
                                           dataset=args.dataset,
                                           arch=args.arch,
                                           epochs=args.epochs,
                                           lr=args.lr,
                                           batch=args.batch_size)
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    print("after train, model will be saved to {}".format(checkpoint_path))

    model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    opt = RAdam(model.parameters(), args.lr, weight_decay=args.weight_decay)
    cudnn.benchmark = True

    make_loader = DataLoaderMaker.get_img_label_data_loader
    train_loader = make_loader(args.dataset, args.batch_size, True)
    val_loader = make_loader(args.dataset, args.batch_size, False)

    # No learning-rate schedule is applied; lr stays at args.lr throughout.
    for epoch in range(args.epochs):
        train(train_loader, model, criterion, opt, epoch, args)
        validate(val_loader, model, criterion, args)
        snapshot = dict(epoch=epoch + 1,
                        arch=args.arch,
                        state_dict=model.state_dict(),
                        optimizer=opt.state_dict())
        save_checkpoint(snapshot, filename=checkpoint_path)
Esempio n. 2
0
def main_train_worker(args):
    """Train the AutoZOOM auto-encoder (``Codec``) on ``args.dataset``, checkpointing each epoch.

    Codec hyper-parameters are chosen from the dataset name: compression
    mode, tanh output, optional resize and input image size. The network is
    moved to CUDA and optimised with RAdam on an MSE loss. After each epoch
    the encoder, decoder and optimizer state, plus ``compress_mode`` and
    ``use_tanh``, are written to a fixed path derived from PY_ROOT and the
    training arguments. The previous checkpoint is overwritten. No
    validation pass is run.

    Args:
        args: namespace-like object. Fields read here are ``gpu``,
            ``dataset``, ``epochs``, ``lr``, ``batch_size`` and
            ``weight_decay``. The whole object is also forwarded to ``train``.

    Raises:
        ValueError: if ``args.dataset`` is not a CIFAR* dataset, ImageNet,
            MNIST or FashionMNIST.
    """
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # (compress_mode, use_tanh, resize, img_size) for each supported dataset family.
    if args.dataset.startswith("CIFAR"):
        compress_mode, use_tanh, resize, img_size = 2, False, None, 32
    elif args.dataset == "ImageNet":
        compress_mode, use_tanh, resize, img_size = 3, True, 128, 299
    elif args.dataset in ("MNIST", "FashionMNIST"):
        compress_mode, use_tanh, resize, img_size = 1, False, None, 28
    else:
        # Fail fast with a clear message; otherwise the locals above stay
        # unbound and the Codec construction below dies with UnboundLocalError.
        raise ValueError(
            "Unsupported dataset for AutoZOOM auto-encoder training: {}".format(args.dataset))

    network = Codec(img_size,
                    IN_CHANNELS[args.dataset],
                    compress_mode,
                    resize=resize,
                    use_tanh=use_tanh)
    model_path = '{}/train_pytorch_model/AutoZOOM/AutoEncoder_{}@compress_{}@use_tanh_{}@epoch_{}@lr_{}@batch_{}.pth.tar'.format(
        PY_ROOT, args.dataset, compress_mode, use_tanh, args.epochs, args.lr,
        args.batch_size)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    print("Model will be saved to {}".format(model_path))
    network.cuda()
    mse_loss_fn = nn.MSELoss().cuda()
    optimizer = RAdam(network.parameters(),
                      args.lr,
                      weight_decay=args.weight_decay)
    cudnn.benchmark = True
    # Images are loaded at the codec's expected input resolution.
    train_loader = DataLoaderMaker.get_img_label_data_loader(
        args.dataset, args.batch_size, True, (img_size, img_size))

    # No learning-rate schedule and no validation pass; lr stays at args.lr.
    for epoch in range(args.epochs):
        train(train_loader, network, mse_loss_fn, optimizer, epoch, args,
              use_tanh)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'encoder': network.encoder.state_dict(),
                'decoder': network.decoder.state_dict(),
                "compress_mode": compress_mode,
                "use_tanh": use_tanh,
                'optimizer': optimizer.state_dict(),
            },
            filename=model_path)