os.makedirs(save_ckpt_dir)
if not os.path.exists(save_log_dir):
    os.makedirs(save_log_dir)

# Parameter settings
param = {}

param['epochs'] = 30          # number of training epochs
param['batch_size'] = 16      # batch size
param['lr'] = 1e-2            # learning rate
param['gamma'] = 0.2          # learning rate decay factor
param['step_size'] = 5        # learning rate decay step (epochs)
param['momentum'] = 0.9       # momentum
param['weight_decay'] = 5e-4  # weight decay
param['disp_inter'] = 1       # display interval (epochs)
param['save_inter'] = 4       # checkpoint save interval (epochs)
param['iter_inter'] = 50      # logging interval within an epoch (batches)
param['min_inter'] = 10

param['model_name'] = model_name          # model name
param['save_log_dir'] = save_log_dir      # log directory
param['save_ckpt_dir'] = save_ckpt_dir    # checkpoint directory

# checkpoint path to load (for resuming training)
param['load_ckpt_dir'] = None

# Training
best_model, model = train_net(param, model, train_data, valid_data)

Example #2
        net = googlelenet.load(len(class_names))
        base_lr = 0.1
        resize = (96, 96)
    elif (net_name == "resnet-N" or net_name == "resnet"):
        net = resnet.load('resnet-N', len(class_names))
    elif net_name == "resnet-164":
        net = resnet.load('resnet-164', len(class_names))
    elif net_name == "densenet":
        net = densenet.load(len(class_names))
    elif net_name == "squeezenet":
        net = squeezenet.load(len(class_names), (96, 3, 1, 1))
        base_lr = 0.01  # must use a small learning rate

# print(output_prefix + '.params')
    if os.path.exists(output_prefix + '.params'):
        net.load_parameters(output_prefix + '.params')
        print('finetuning from', output_prefix + '.params')

    #for param in net.collect_params():
    #    if net.collect_params()[param].grad_req != "null":
    #        pp = net.collect_params()[param].grad()
    #print(net.collect_params()[param].grad.values())
    # sw.add_histogram(tag=key, values=value.grads(), global_step=iter_num, bins=1000)

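    # move the parameters onto the training context(s); trainer and learning-rate schedule come from the selected config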
    net.collect_params().reset_ctx(ctx)
    trainer = configs[config_name].trainer(net)
    lr_sch = configs[config_name].learning_rate_scheduler()

    train_net(net, train_iter, test_iter, batch_size, trainer, ctx, num_epochs,
              lr_sch, output_prefix)
Example #3
def main():
    # train args
    parser = argparse.ArgumentParser(
        description='Distributional Sliced Wasserstein Autoencoder')
    parser.add_argument('--datadir', default='./', help='path to dataset')
    parser.add_argument('--outdir',
                        default='./result',
                        help='directory to output images')
    parser.add_argument('--batch-size',
                        type=int,
                        default=512,
                        metavar='N',
                        help='input batch size for training (default: 512)')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        metavar='LR',
                        help='learning rate (default: 0.0005)')
    parser.add_argument(
        '--num-workers',
        type=int,
        default=16,
        metavar='N',
        help='number of dataloader workers if device is CPU (default: 16)')
    parser.add_argument('--seed',
                        type=int,
                        default=16,
                        metavar='S',
                        help='random seed (default: 16)')
    parser.add_argument('--g', type=str, default='circular', help='g')
    parser.add_argument('--num-projection',
                        type=int,
                        default=1000,
                        help='number of projections')
    parser.add_argument('--lam',
                        type=float,
                        default=1,
                        help='Regularization strength')
    parser.add_argument('--p', type=int, default=2, help='Norm p')
    parser.add_argument('--niter',
                        type=int,
                        default=10,
                        help='number of iterations')
    parser.add_argument('--r', type=float, default=1000, help='R')
    parser.add_argument('--latent-size',
                        type=int,
                        default=32,
                        help='Latent size')
    parser.add_argument('--dataset',
                        type=str,
                        default='MNIST',
                        help='(CELEBA|CIFAR)')
    parser.add_argument('--model-type',
                        type=str,
                        required=True,
                        help='(SWD|MSWD|DSWD|GSWD|DGSWD|CRAMER)')
    args = parser.parse_args()
    torch.random.manual_seed(args.seed)
    if (args.g == 'circular'):
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    model_dir = os.path.join(args.outdir, model_type)
    assert dataset in ['CELEBA', 'CIFAR']
    assert model_type in ['SWD', 'MSWD', 'DSWD', 'GSWD', 'DGSWD', 'CRAMER']
    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print('batch size {}\nepochs {}\nAdam lr {} \n using device {}\n'.format(
        args.batch_size, args.epochs, args.lr, device.type))

    if (dataset == 'CIFAR'):
        from DCGANAE import Discriminator
        image_size = 64
        num_chanel = 3
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                args.datadir,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(64),
                    transforms.ToTensor(),
                    #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers)

    elif (dataset == 'CELEBA'):
        from DCGANAE import Discriminator
        image_size = 64
        num_chanel = 3
        dataset = CustomDataset(
            root=args.datadir + '/img_align_celeba',
            transform=transforms.Compose([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
        # Create the dataloader
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True)

    model = DCGANAE(image_size=64,
                    latent_size=latent_size,
                    num_chanel=3,
                    hidden_chanels=64,
                    device=device).to(device)
    dis = Discriminator(64, args.latent_size, 3, 64).to(device)
    disoptimizer = optim.Adam(dis.parameters(), lr=args.lr, betas=(0.5, 0.999))
    if (model_type == 'DSWD' or model_type == 'DGSWD'):
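        # DSWD/DGSWD train an extra projection (transform) network; 64 * 8 * 4 * 4 is presumably the discriminator's flattened feature size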
        transform_net = TransformNet(64 * 8 * 4 * 4).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=args.lr,
                                betas=(0.5, 0.999))
        train_net(64 * 8 * 4 * 4, 1000, transform_net, op_trannet)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    epoch_cont = 0

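    # fixed noise so the sample grids written after each epoch are directly comparable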
    fixednoise = torch.randn((64, latent_size)).to(device)

    for epoch in range(epoch_cont, args.epochs):
        total_loss = 0.0

        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            if (model_type == 'SWD'):
                loss = model.compute_loss_SWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p)
            elif (model_type == 'GSWD'):
                loss = model.compute_loss_GSWD(dis,
                                               disoptimizer,
                                               data,
                                               torch.randn,
                                               g_function,
                                               args.r,
                                               num_projection,
                                               p=args.p)
            elif (model_type == 'MSWD'):
                loss, v = model.compute_loss_MSWD(dis,
                                                  disoptimizer,
                                                  data,
                                                  torch.randn,
                                                  p=args.p,
                                                  max_iter=args.niter)
            elif (model_type == 'DSWD'):
                loss = model.compute_lossDSWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              transform_net,
                                              op_trannet,
                                              p=args.p,
                                              max_iter=args.niter,
                                              lam=args.lam)
            elif (model_type == 'DGSWD'):
                loss = model.compute_lossDGSWD(dis,
                                               disoptimizer,
                                               data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               g_function,
                                               args.r,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=args.lam)
            elif (model_type == 'CRAMER'):
                loss = model.compute_loss_cramer(dis, disoptimizer, data,
                                                 torch.randn)

            optimizer.zero_grad()
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

        total_loss /= (batch_idx + 1)
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))

        if (epoch % 1 == 0 or epoch == args.epochs - 1):
            sampling(model_dir + '/sample_epoch_' + str(epoch) + ".png",
                     fixednoise, model.decoder, 64, image_size, num_chanel)
Example #4
def main():
    # train args
    parser = argparse.ArgumentParser(
        description='Augmented Sliced Wasserstein Autoencoder')
    parser.add_argument('--datadir', default='./', help='path to dataset')
    parser.add_argument('--outdir',
                        default='./result/',
                        help='directory to output images')
    parser.add_argument('--batch-size',
                        type=int,
                        default=512,
                        metavar='N',
                        help='input batch size for training (default: 512)')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0005,
                        metavar='LR',
                        help='learning rate (default: 0.0005)')
    parser.add_argument(
        '--num-workers',
        type=int,
        default=16,
        metavar='N',
        help='number of dataloader workers if device is CPU (default: 16)')
    parser.add_argument('--seed',
                        type=int,
                        default=11,
                        metavar='S',
                        help='random seed (default: 11)')
    parser.add_argument('--g', type=str, default='circular', help='g')
    parser.add_argument('--num-projection',
                        type=int,
                        default=1000,
                        help='number of projections')
    parser.add_argument('--lam',
                        type=float,
                        default=0.5,
                        help='Regularization strength')
    parser.add_argument('--p', type=int, default=2, help='Norm p')
    parser.add_argument('--niter',
                        type=int,
                        default=5,
                        help='number of iterations')
    parser.add_argument('--r', type=float, default=1000, help='R')
    parser.add_argument('--latent-size',
                        type=int,
                        default=32,
                        help='Latent size')
    parser.add_argument('--dataset',
                        type=str,
                        default='CIFAR',
                        help='(CELEBA|CIFAR)')
    parser.add_argument('--model-type',
                        type=str,
                        required=True,
                        help='(ASWD|SWD|MSWD|DSWD|GSWD)')
    parser.add_argument('--gpu', type=str, required=False, default='0')
    args = parser.parse_args()
    torch.random.manual_seed(args.seed)
    if (args.g == 'circular'):
        g_function = circular_function
    model_type = args.model_type
    latent_size = args.latent_size
    num_projection = args.num_projection
    dataset = args.dataset
    model_dir = os.path.join(args.outdir, model_type)
    assert dataset in ['CELEBA', 'CIFAR']
    assert model_type in ['ASWD', 'SWD', 'MSWD', 'DSWD', 'GSWD']
    if not (os.path.isdir(args.datadir)):
        os.makedirs(args.datadir)
    if not (os.path.isdir(args.outdir)):
        os.makedirs(args.outdir)
    if not (os.path.isdir(model_dir)):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:" + str(args.gpu) if use_cuda else "cpu")
    print('batch size {}\nepochs {}\nAdam lr {} \n using device {}\n'.format(
        args.batch_size, args.epochs, args.lr, device.type))

    if (dataset == 'CIFAR'):
        from DCGANAE import Discriminator
        image_size = 64
        num_chanel = 3
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(
                args.datadir,
                train=True,
                download=True,
                transform=transforms.Compose([
                    transforms.Resize(64),
                    transforms.ToTensor(),
                    #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ])),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers)

    elif (dataset == 'CELEBA'):
        from DCGANAE import Discriminator
        image_size = 64
        num_chanel = 3
        dataset = CustomDataset(
            root=args.datadir + 'img_align_celeba',
            transform=transforms.Compose([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]))
        # Create the dataloader
        train_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True)

    model = DCGANAE(image_size=64,
                    latent_size=latent_size,
                    num_chanel=3,
                    hidden_chanels=64,
                    device=device).to(device)
    #model=nn.DataParallel(model)
    #model.to(device)
    dis = Discriminator(64, args.latent_size, 3, 64).to(device)
    disoptimizer = optim.Adam(dis.parameters(), lr=args.lr, betas=(0.5, 0.999))
    if (model_type == 'DSWD' or model_type == 'DGSWD'):
        transform_net = TransformNet(64 * 8 * 4 * 4).to(device)
        op_trannet = optim.Adam(transform_net.parameters(),
                                lr=0.0005,
                                betas=(0.5, 0.999))
        train_net(64 * 8 * 4 * 4,
                  1000,
                  transform_net,
                  op_trannet,
                  device=device)

    if model_type == 'ASWD':
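        # ASWD trains an additional mapping network phi; 64 * 8 * 4 * 4 is presumably the discriminator's flattened feature size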
        phi = Mapping(64 * 8 * 4 * 4).to(device)
        phi_op = optim.Adam(phi.parameters(), lr=0.001, betas=(0.5, 0.999))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
    epoch_cont = 0
    generated_sample_number = 64
    fixednoise = torch.randn((generated_sample_number, latent_size)).to(device)
    loss_recorder = []
    generated_sample_size = 64
    W2_recorder = np.zeros([41, 40])
    save_idx = str(time.time()).split('.')
    save_idx = save_idx[0] + save_idx[1]
    path_0 = model_dir + '/' + args.dataset + '/fid/' + save_idx
    os.makedirs(path_0)
    if args.dataset == 'CIFAR':
        interval_ = 10
        fid_stats_file = args.datadir + '/fid_stats_cifar10_train.npz'
    else:
        interval_ = 5
        fid_stats_file = args.datadir + '/fid_stats_celeba.npz'
    fid_recorder = np.zeros(args.epochs // interval_ + 1)
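    # training loop: FID is evaluated every interval_ epochs, while samples and checkpoints are written every epoch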
    for epoch in range(epoch_cont, args.epochs):
        total_loss = 0.0

        for batch_idx, (data, y) in tqdm(enumerate(train_loader, start=0)):
            if (model_type == 'SWD'):
                loss = model.compute_loss_SWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              p=args.p,
                                              epoch=epoch,
                                              batch_idx=batch_idx)
            elif (model_type == 'GSWD'):
                loss = model.compute_loss_GSWD(dis,
                                               disoptimizer,
                                               data,
                                               torch.randn,
                                               g_function,
                                               args.r,
                                               num_projection,
                                               p=args.p)
            elif (model_type == 'MSWD'):
                loss, v = model.compute_loss_MSWD(dis,
                                                  disoptimizer,
                                                  data,
                                                  torch.randn,
                                                  p=args.p,
                                                  max_iter=args.niter)
            elif (model_type == 'DSWD'):
                loss = model.compute_lossDSWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              transform_net,
                                              op_trannet,
                                              p=args.p,
                                              max_iter=args.niter,
                                              lam=10)
            elif (model_type == 'DGSWD'):
                loss = model.compute_lossDGSWD(dis,
                                               disoptimizer,
                                               data,
                                               torch.randn,
                                               num_projection,
                                               transform_net,
                                               op_trannet,
                                               g_function,
                                               args.r,
                                               p=args.p,
                                               max_iter=args.niter,
                                               lam=10)
            elif (model_type == 'SINKHORN'):
                loss = model.compute_loss_sinkhorn(dis,
                                                   disoptimizer,
                                                   data,
                                                   torch.randn,
                                                   p=2,
                                                   n_iter=100,
                                                   e=100)
            elif (model_type == 'CRAMER'):
                loss = model.compute_loss_cramer(dis, disoptimizer, data,
                                                 torch.randn)
            elif model_type == 'ASWD':
                loss = model.compute_lossASWD(dis,
                                              disoptimizer,
                                              data,
                                              torch.randn,
                                              num_projection,
                                              phi,
                                              phi_op,
                                              p=2,
                                              max_iter=args.niter,
                                              lam=args.lam,
                                              epoch=epoch,
                                              batch_idx=batch_idx)
            optimizer.zero_grad()
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
        if epoch == 0 or (epoch + 1) % interval_ == 0:
            path_1 = path_0 + '/' + str('%03d' % epoch)
            os.mkdir(path_1)
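            # draw 20 batches of 1000 decoder samples (20,000 images) and write them to disk for the FID computation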
            for j in range(20):
                fixednoise_ = torch.randn((1000, 32)).to(device)
                imgs = model.decoder(fixednoise_)

                for i, img in enumerate(imgs):
                    img = img.transpose(0, -1).transpose(
                        0, 1).cpu().detach().numpy()
                    img = (img * 255).astype(np.uint8)
                    imageio.imwrite(
                        path_1 + '/' + args.model_type + '_' +
                        str(args.num_projection) + '_' + str('%03d' % epoch) +
                        '_' + str(1000 * j + i) + '.png', img)
            fid_value = calculate_fid_given_paths(
                [path_1 + '/', fid_stats_file], 50, True, 2048)
            fid_recorder[(epoch + 1) // interval_] = fid_value
            np.save(
                path_0 + '/fid_recorder_' + 'np_' + str(num_projection) +
                '.npy', fid_recorder)
            print('fid score:', fid_value)
            os.system("rm -rf " + path_1)
        total_loss /= (batch_idx + 1)
        loss_recorder.append(total_loss)
        print("Epoch: " + str(epoch) + " Loss: " + str(total_loss))

        if (epoch % 1 == 0 or epoch == args.epochs - 1):
            sampling(
                model_dir + '/' + args.dataset + '/sample_epoch_' +
                str(epoch) + ".png", fixednoise, model.decoder,
                generated_sample_number, generated_sample_size, num_chanel)
            torch.save(
                model.state_dict(), args.outdir + args.dataset + '/' +
                model_type + '_' + str(args.batch_size) + '_' +
                str(num_projection) + '_' + str(latent_size) + '_model.pth')
            torch.save(
                dis.state_dict(),
                args.outdir + args.dataset + '/' + model_type + '_' +
                str(args.batch_size) + '_' + str(num_projection) + '_' +
                str(latent_size) + '_discriminator.pth')
    np.save(
        args.outdir + args.dataset + '/' + model_type + '_' +
        str(args.batch_size) + '_' + str(num_projection) + '_' +
        str(latent_size) + '_loss.npy', loss_recorder)
Example #5
if args.single_run:
    batches = [batch_num]
else:
    batches = [5, 10, 20, 30, 40, 50, 100]
    # batches = [1, 2]

for batch_num in batches:
    training_set_size = batch_num * train_loader.batch_size
    if train_cnn:
        print("Starting CNN")
        t1 = time.time()
        cnn_copy = copy.deepcopy(cnn)
        optimizer_cnn = torch.optim.Adam(cnn_copy.parameters(),
                                         lr=learning_rate)
        _, cnn_loss, cnn_acc, cnn_loss_validation = utils.train_net(
            cnn_copy, train_loader, test_loader, criterion_cnn, optimizer_cnn,
            batch_num, epoch_num)
        print("DONE TRAINING CNN: {}s".format(time.time() - t1))
        print("FINAL ACC:", cnn_acc[-1], '\n\n')

    if train_pinn:
        print("Starting PINN")
        t2 = time.time()
        pinn_copy = copy.deepcopy(pinn)
        optimizer_pinn = torch.optim.Adam(pinn_copy.parameters(),
                                          lr=learning_rate)
        after_pinn, pinn_loss, pinn_acc, pinn_loss_validation = utils.train_net(
            pinn_copy,
            train_loader,
            test_loader,
            criterion_pinn,
Example #6
from PairImageFolder import DownsizedPairImageFolder
from EnhanceNet import EnhanceNet

from torch.utils.data import DataLoader
from torchvision import transforms

from utils import train_net, save_result_img, path_check
import os

if __name__ == '__main__':

    #dataset
    train_data = DownsizedPairImageFolder('./dataset/train', transform=transforms.ToTensor())
    test_data = DownsizedPairImageFolder('./dataset/test', transform=transforms.ToTensor())

    #dataloader
    batch_size = 32
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)


    # network
    net = EnhanceNet()

    # training
    train_net(net, train_loader, test_loader, device='cuda:0')

    # save result
    dst = './result'
    f_name = 'cnn_upscale.jpg'

    path_check(dst)
    save_result_img(net, test_data, os.path.join(dst,f_name))
Example #7
    train_loader = DataLoader(train_imgs, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_imgs, batch_size=32, shuffle=False)

    # print classes
    print(train_imgs)
    print("classes : ", train_imgs.classes)
    print("class to idx : ", train_imgs.class_to_idx)

    # pre-trained model
    net = models.resnet18(pretrained=True)

    # turn off autograd (freeze the pre-trained layers)
    for p in net.parameters():
        p.requires_grad = False

    # change the FC layer
    fc_input_dim = net.fc.in_features
    net.fc = nn.Linear(fc_input_dim, 2)
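    # the new 2-class head has requires_grad=True by default, so it is the only part that gets trained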

    # If we only want the features, we can replace the FC layer with an IdentityLayer.
    # This is just one of several possible approaches.
    """
    class IdentityLayer(nn.Module):
        def forward(self, x):
            return x
    net.fc = IdentityLayer()
    """

    # start training
    train_net(net, train_loader, test_loader, n_iter=20, device='cuda:0')
Example #8
nn['dropout2'] = layers.DropoutLayer(nn['fc2'], p = 0.5)

nn['output'] = layers.DenseLayer(nn['dropout2'], num_units=num_classes,                                 
                                 nonlinearity=nonlinearities.softmax)

network = nn['output']

y_predicted = lasagne.layers.get_output(network)
all_weights = lasagne.layers.get_all_params(network)

loss = lasagne.objectives.categorical_crossentropy(y_predicted, target_y).mean()
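# add L2 penalties on the two fully connected layers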
l2_penalty_1 = regularize_layer_params(nn['fc1'], l2)
l2_penalty_2 = regularize_layer_params(nn['fc2'], l2)
loss += l2_penalty_1 + l2_penalty_2
accuracy = lasagne.objectives.categorical_accuracy(y_predicted, target_y).mean()
updates_sgd = lasagne.updates.nesterov_momentum(loss, all_weights, learning_rate=0.01, momentum=0.9)

train_fun = theano.function([input_X, target_y], [loss, accuracy], allow_input_downcast=True, updates=updates_sgd)
test_fun  = theano.function([input_X, target_y], [loss, accuracy], allow_input_downcast=True)

print('starting training')

conv_nn = train_net(network, train_fun, test_fun, X_train, y_train, X_valid, y_valid, num_epochs=500, batch_size=300)

final_weights = lasagne.layers.get_all_params(network)

np.save('weights_{}'.format(datetime.datetime.now()), final_weights)


Example #9
# dataset link : http://www.robots.ox.ac.uk/~vgg/data/flowers/102/
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

from models import GNet, DNet
from utils import train_net

if __name__ == '__main__':

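    # resize to 80 px then center-crop to 64x64, presumably the resolution GNet and DNet expect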
    img_data = ImageFolder('./dataset', transform=transforms.Compose([transforms.Resize(80),
                                                                      transforms.CenterCrop(64),
                                                                      transforms.ToTensor()]))

    batch_size = 64
    img_loader = DataLoader(img_data, batch_size=batch_size, shuffle=True)

    g = GNet().to('cuda:0')
    d = DNet().to('cuda:0')

    train_net(g, d, img_loader, batch_size=batch_size, device='cuda:0')
Example #10
def main():
    SEED = 42
    torch.manual_seed(SEED)

    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--map_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="Folder containing maps")
    parser.add_argument(
        "--goal_dir",
        default=None,
        type=str,
        required=True,
        help="Folder containing goals for maps. See dataset class for info.")
    parser.add_argument(
        "--heuristic_dir",
        default=None,
        type=str,
        required=True,
        help=
        "Folder containing heurisctics for maps. See dataset class for info.")
    parser.add_argument(
        "--map_to_heuristic",
        default=None,
        type=str,
        required=True,
        help=
        "json file with maps names as keys and heuristic files as values. Note that goal and heuristic for one task should have the same names."
    )

    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: small, big")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )

    parser.add_argument('--alpha',
                        type=float,
                        default=0.0,
                        required=True,
                        help="Weight for gradient loss.")
    parser.add_argument(
        '--alpha1',
        type=float,
        default=1.0,
        required=True,
        help=
        "Weight for component of piece loss where output heuristic is less than minimal cost."
    )
    parser.add_argument(
        '--alpha2',
        type=float,
        default=0.0,
        required=True,
        help=
        "Weight for component of piece loss where output heuristic is more than target cost."
    )

    parser.add_argument("--batch_size",
                        default=32,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--learning_rate",
                        default=1e-3,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument(
        '--desired_batch_size',
        type=int,
        default=32,
        help=
        "Desired batch size to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--num_train_epochs",
                        default=10,
                        type=int,
                        help="Total number of training epochs to perform.")

    args = parser.parse_args()
    alpha = args.alpha
    alpha1 = args.alpha1
    alpha2 = args.alpha2

    if args.model_type == 'small':
        model = SmallUNet()
    elif args.model_type == 'big':
        model = UNet()
    else:
        raise ValueError('Model type should be in [small, big]')

    learning_rate = args.learning_rate
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = lambda output, target_map, minimal_cost: loss(
        output, target_map, minimal_cost, device, alpha, alpha1, alpha2)

    exp_name = f'alpha_{alpha}_alpha1_{alpha1}_alpha2_{alpha2}'

    MAP_DIR = args.map_dir
    HEURISTIC_DIR = args.heuristic_dir
    GOAL_DIR = args.goal_dir
    map2heuristic_path = args.map_to_heuristic
    output_dir = args.output_dir

    with open(map2heuristic_path, 'r') as file:
        map2heuristic = json.load(file)

    batch_size = args.batch_size
    num_epochs = args.num_train_epochs
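    # target batch size for gradient accumulation; never smaller than the per-step batch size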
    desired_batch_size = args.desired_batch_size if args.desired_batch_size > batch_size else batch_size

    config = {
        'learning_rate': learning_rate,
        'alpha': alpha,
        'alpha1': alpha1,
        'alpha2': alpha2,
        'num_epochs': num_epochs,
        'batch_size': batch_size,
        'desired_batch_size': desired_batch_size
    }

    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    with open(os.path.join(output_dir, 'config.json'), 'w') as file:
        json.dump(config, file)

    dataset = MapsDataset(MAP_DIR,
                          HEURISTIC_DIR,
                          GOAL_DIR,
                          map2heuristic,
                          maps_size=(64, 64))
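    # random_split below assumes the dataset contains exactly 50,000 maps (40,000 train / 10,000 validation)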
    train_dataset, val_dataset = random_split(dataset, [40000, 10000])
    train_batch_gen = DataLoader(train_dataset,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 pin_memory=True,
                                 num_workers=cpu_count())
    val_batch_gen = DataLoader(val_dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               pin_memory=True,
                               num_workers=cpu_count())

    _ = train_net(model,
                  criterion,
                  optimizer,
                  train_batch_gen,
                  val_batch_gen,
                  device,
                  num_epochs=num_epochs,
                  output_dir=output_dir,
                  desired_batch_size=desired_batch_size,
                  exp_name=exp_name)
Example #11
# dataset : https://github.com/karpathy/char-rnn/tree/master/data/tinyshakespeare
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from ShakespeareDataset import ShakespeareDataset
from model import SequenceGenerationNet
from utils import train_net

if __name__ == '__main__':
    batch_size = 32

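    # chunk_size=200: the corpus is presumably served as 200-character training sequences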
    text_dataset = ShakespeareDataset('./dataset/tinyshakespeare/input.txt', chunk_size=200)
    data_loader = DataLoader(text_dataset, batch_size=batch_size, shuffle=True)

    net = SequenceGenerationNet(num_embeddings=text_dataset.vocab_size,
                                embedding_dim=20,
                                hidden_size=50,
                                num_layers=2,
                                dropout=0.1)
    net = net.to('cuda:0')

    train_net(net, data_loader, text_dataset,
              n_iter=2, device='cuda:0')