Exemple #1
0
            output = model(lr_patch)
            l1_loss = L1Loss(output, hr_patch)
            p_loss = 0.0006 * perceptualLoss(output, hr_patch)
            loss = l1_loss + p_loss
            loss.backward()
            optimizer.step()
            train_loss_l1 += l1_loss.item()
            train_loss_p += p_loss.item()

        print('Train set : l1 loss: {:.4f} p loss: {:.4f} '.format(
            train_loss_l1, train_loss_p))
        print('time ', time.time() - start_time)

        # save ckpt each epoch
        utils.save_networks(model,
                            save_path=os.path.join(args.output_root,
                                                   'checkpoints',
                                                   'epoch_%d.pth' % epoch))
        scheduler.step(epoch)

        # check val performance
        if epoch % test_interval == 0:
            test_start_time = time.time()
            with torch.no_grad():
                print('------------------------')
                for module in model.children():
                    module.train(False)
                test_ite = 0
                test_psnr = 0
                test_ssim = 0
                eps = 1e-10
Exemple #2
0
def main_worker(options):
    """Train, or only evaluate, a classifier configured by ``options``.

    ``options`` is a dict. Besides the keys read below it is updated in place
    with ``num_classes``, ``feat_dim`` and ``use_gpu`` before the loss module
    ``loss.<options['loss']>`` is imported and instantiated with ``**options``.

    When ``options['eval']`` is truthy the stored networks are loaded, tested
    once and the test results are returned. Otherwise the network is trained
    for ``options['max_epoch']`` epochs (with an additional ``train_cs`` pass
    per epoch when ``options['cs']`` is set), tested every
    ``options['eval_freq']`` epochs and after the last epoch, and saved after
    every test.

    Returns:
        The dict returned by the most recent ``test`` call, or ``None`` when
        no epoch ran (``max_epoch`` <= 0).
    """
    torch.manual_seed(options['seed'])
    os.environ['CUDA_VISIBLE_DEVICES'] = options['gpu']
    use_gpu = torch.cuda.is_available()
    if options['use_cpu']:
        use_gpu = False

    if use_gpu:
        print("Currently using GPU: {}".format(options['gpu']))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(options['seed'])
    else:
        print("Currently using CPU")

    # Dataset
    print("{} Preparation".format(options['dataset']))
    loader_kwargs = dict(dataroot=options['dataroot'],
                         batch_size=options['batch_size'],
                         img_size=options['img_size'])
    if 'mnist' in options['dataset']:
        Data = MNIST_OSR(known=options['known'], **loader_kwargs)
        trainloader, testloader, outloader = Data.train_loader, Data.test_loader, Data.out_loader
    elif 'cifar10' == options['dataset']:
        Data = CIFAR10_OSR(known=options['known'], **loader_kwargs)
        trainloader, testloader, outloader = Data.train_loader, Data.test_loader, Data.out_loader
    elif 'svhn' in options['dataset']:
        Data = SVHN_OSR(known=options['known'], **loader_kwargs)
        trainloader, testloader, outloader = Data.train_loader, Data.test_loader, Data.out_loader
    elif 'cifar100' in options['dataset']:
        # Train/test on the CIFAR10 split; the "unknown" loader comes from
        # the CIFAR100 test split built from options['unknown'].
        Data = CIFAR10_OSR(known=options['known'], **loader_kwargs)
        trainloader, testloader = Data.train_loader, Data.test_loader
        out_Data = CIFAR100_OSR(known=options['unknown'], **loader_kwargs)
        outloader = out_Data.test_loader
    else:
        Data = Tiny_ImageNet_OSR(known=options['known'], **loader_kwargs)
        trainloader, testloader, outloader = Data.train_loader, Data.test_loader, Data.out_loader

    options['num_classes'] = Data.num_classes

    # Model
    print("Creating model: {}".format(options['model']))
    if options['cs']:
        net = classifier32ABN(num_classes=options['num_classes'])
    else:
        net = classifier32(num_classes=options['num_classes'])
    feat_dim = 128

    if options['cs']:
        print("Creating GAN")
        nz = options['nz']
        if 'tiny_imagenet' in options['dataset']:
            netG = gan.Generator(1, nz, 64, 3)
            netD = gan.Discriminator(1, 3, 64)
        else:
            netG = gan.Generator32(1, nz, 64, 3)
            netD = gan.Discriminator32(1, 3, 64)
        fixed_noise = torch.FloatTensor(64, nz, 1, 1).normal_(0, 1)
        criterionD = nn.BCELoss()

    # Loss
    options.update(
        {
            'feat_dim': feat_dim,
            'use_gpu':  use_gpu
        }
    )

    Loss = importlib.import_module('loss.'+options['loss'])
    criterion = getattr(Loss, options['loss'])(**options)

    if use_gpu:
        net = nn.DataParallel(net).cuda()
        criterion = criterion.cuda()
        if options['cs']:
            # One logical device id per entry of the comma-separated GPU list.
            device_ids = list(range(len(options['gpu'].split(','))))
            netG = nn.DataParallel(netG, device_ids=device_ids).cuda()
            netD = nn.DataParallel(netD, device_ids=device_ids).cuda()
            # Tensor.cuda() returns a copy; the result has to be rebound.
            fixed_noise = fixed_noise.cuda()

    model_path = os.path.join(options['outf'], 'models', options['dataset'])
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    if options['dataset'] == 'cifar100':
        # NOTE(review): the directory above is created before the '_50'
        # suffix is appended; the suffixed path is presumably created by
        # save_networks - confirm.
        model_path += '_50'
        file_name = '{}_{}_{}_{}_{}'.format(options['model'], options['loss'], 50, options['item'], options['cs'])
    else:
        file_name = '{}_{}_{}_{}'.format(options['model'], options['loss'], options['item'], options['cs'])

    if options['eval']:
        net, criterion = load_networks(net, model_path, file_name, criterion=criterion)
        results = test(net, criterion, testloader, outloader, epoch=0, **options)
        print("Acc (%): {:.3f}\t AUROC (%): {:.3f}\t OSCR (%): {:.3f}\t".format(results['ACC'], results['AUROC'], results['OSCR']))

        return results

    params_list = [{'params': net.parameters()},
                   {'params': criterion.parameters()}]

    if options['dataset'] == 'tiny_imagenet':
        optimizer = torch.optim.Adam(params_list, lr=options['lr'])
    else:
        optimizer = torch.optim.SGD(params_list, lr=options['lr'], momentum=0.9, weight_decay=1e-4)
    if options['cs']:
        optimizerD = torch.optim.Adam(netD.parameters(), lr=options['gan_lr'], betas=(0.5, 0.999))
        optimizerG = torch.optim.Adam(netG.parameters(), lr=options['gan_lr'], betas=(0.5, 0.999))

    use_scheduler = options['stepsize'] > 0
    if use_scheduler:
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 90, 120])

    start_time = time.time()

    # Stays None only when the loop below never runs; the last epoch always
    # triggers a test.
    results = None
    for epoch in range(options['max_epoch']):
        print("==> Epoch {}/{}".format(epoch+1, options['max_epoch']))

        if options['cs']:
            train_cs(net, netD, netG, criterion, criterionD,
                     optimizer, optimizerD, optimizerG,
                     trainloader, epoch=epoch, **options)

        train(net, criterion, optimizer, trainloader, epoch=epoch, **options)

        is_eval_epoch = options['eval_freq'] > 0 and (epoch+1) % options['eval_freq'] == 0
        is_last_epoch = (epoch+1) == options['max_epoch']
        if is_eval_epoch or is_last_epoch:
            print("==> Test", options['loss'])
            results = test(net, criterion, testloader, outloader, epoch=epoch, **options)
            print("Acc (%): {:.3f}\t AUROC (%): {:.3f}\t OSCR (%): {:.3f}\t".format(results['ACC'], results['AUROC'], results['OSCR']))

            save_networks(net, model_path, file_name, criterion=criterion)

        if use_scheduler:
            scheduler.step()

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))

    return results
Exemple #3
0
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       centers=None):
    """Train one discriminator/generator pair on ``reals[len(Gs)]``.

    The function runs ``opt.niter`` iterations, each made of ``opt.Dsteps``
    discriminator updates followed by ``opt.Gsteps`` generator updates, and
    periodically writes sample images and ``z_opt.pth`` under ``opt.outf``.
    It mutates ``opt`` (``nzx``, ``nzy``, ``receptive_field``, ``noise_amp``).

    Args:
        netD: discriminator being trained.
        netG: generator being trained; called as ``netG(noise, prev)``.
        reals: sequence of real images; the entry at index ``len(Gs)`` is
            used and indexed as a 4-D tensor.
        Gs, Zs, NoiseAmp: state of the previously trained scales, forwarded
            to ``draw_concat``. An empty ``Gs`` selects the first-scale path.
        in_s: input forwarded to ``draw_concat``; replaced by a zero tensor
            on the first scale (when ``opt.mode`` is not ``'SR_train'``).
        opt: options namespace (``ker_size``, ``num_layer``, ``stride``,
            ``mode``, ``alpha``, ``nc_z``, ``device``, learning rates, ...).
        centers: passed to ``functions.quant2centers`` in ``'paint_train'``
            mode only.

    Returns:
        Tuple ``(z_opt, in_s, netG)``.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) *
                                          (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Noise is generated at the already-enlarged size, so it is not
        # padded again.
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(pad_noise)
    m_image = nn.ZeroPad2d(pad_image)

    alpha = opt.alpha

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    # torch.zeros yields the default floating dtype. torch.full(shape, 0)
    # infers an integer dtype from the int fill value on current PyTorch,
    # which would make this (and the tensors derived from it) Long tensors.
    z_opt = torch.zeros(fixed_noise.shape, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr_d,
                            betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr_g,
                            betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600],
                                                      gamma=opt.gamma)

    # Per-iteration statistics; collected but not returned by this function.
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if not Gs and opt.mode != 'SR_train':
            # First scale: single-channel noise broadcast over 3 channels.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                             device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: minimize -D(x) + D(G(z)) + gradient penalty,
        #     i.e. maximize D(x) - D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if j == 0 and epoch == 0:
                # Very first step: establish prev / z_prev and the noise
                # amplitude used for the rest of this call.
                if not Gs and opt.mode != 'SR_train':
                    prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                       device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy],
                                         device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf),
                           functions.convert_image_np(prev),
                           vmin=0,
                           vmax=1)

            if not Gs and opt.mode != 'SR_train':
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            # detach(): the discriminator step must not backprop into netG.
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################

        for j in range(opt.Gsteps):
            netG.zero_grad()
            # Reuses the last `fake` produced in the discriminator loop.
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                # Reconstruction term: G(noise_amp * z_opt + z_prev) vs real.
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf),
                               functions.convert_image_np(z_prev),
                               vmin=0,
                               vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf),
                       functions.convert_image_np(fake.detach()),
                       vmin=0,
                       vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(
                           netG(Z_opt.detach(), z_prev).detach()),
                       vmin=0,
                       vmax=1)

            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
Exemple #4
0
def main():
    """Parse options, train the selected solver, save it and render a test image.

    Steps, in order: create the experiment / model / image folders, optionally
    set up a tensorboardX writer, build the ``w1`` / ``w2`` / ``bary_ot``
    model, run the dual stage (``bary_ot`` only) and the main training loop,
    save the networks, then push the pickled batch from
    ``./mvg_test/data.pkl`` through ``model.g`` and write ``test.png``.

    Raises:
        ValueError: if ``config.solver`` is not one of the three known names.
    """
    ## parse flags
    config = Options().parse()
    utils.print_opts(config)

    ## set up folders
    exp_dir = os.path.join(config.exp_dir, config.exp_name)
    model_dir = os.path.join(exp_dir, 'models')
    img_dir = os.path.join(exp_dir, 'images')
    for folder in (exp_dir, model_dir, img_dir):
        os.makedirs(folder, exist_ok=True)

    if config.use_tbx:
        # remove every old tensorboardX log, not just the first match
        for old_log in glob.glob(os.path.join(exp_dir, 'events.out.tfevents.*')):
            os.remove(old_log)
        tbx_writer = SummaryWriter(exp_dir)
    else:
        tbx_writer = None

    ## initialize data loaders/generators & model
    r_loader, z_loader = get_loader(config)
    if config.solver == 'w1':
        model = W1(config, r_loader, z_loader)
    elif config.solver == 'w2':
        model = W2(config, r_loader, z_loader)
    elif config.solver == 'bary_ot':
        model = BaryOT(config, r_loader, z_loader)
    else:
        # Fail here with a clear message instead of an unbound `model` below.
        raise ValueError("unknown solver: %r" % (config.solver,))
    cudnn.benchmark = True
    networks = model.get_networks()
    utils.print_networks(networks)

    ## training
    ## stage 1 (dual stage) of bary_ot
    is_bary_ot = config.solver == 'bary_ot'
    start_time = time.time()
    if is_bary_ot:
        print("Starting: dual stage for %d iters." % config.dual_iters)
        for step in range(config.dual_iters):
            model.train_diter_only(config)
            if ((step + 1) % 100) == 0:
                stats = model.get_stats(config)
                end_time = time.time()
                # minutes elapsed since the previous report
                stats['disp_time'] = (end_time - start_time) / 60.
                start_time = end_time
                utils.print_out(stats, step + 1, config.dual_iters, tbx_writer)
        print("dual stage iterations complete.")

    ## main training loop of w1 / w2 or stage 2 (map stage) of bary-ot
    map_iters = config.map_iters if is_bary_ot else config.train_iters
    if is_bary_ot:
        print("Starting: map stage for %d iters." % map_iters)
    else:
        print("Starting training...")
    for step in range(map_iters):
        model.train_iter(config)
        if ((step + 1) % 100) == 0:
            stats = model.get_stats(config)
            end_time = time.time()
            stats['disp_time'] = (end_time - start_time) / 60.
            start_time = end_time
            utils.print_out(stats, step + 1, map_iters, tbx_writer)
        if ((step + 1) % 500) == 0:
            images = model.get_visuals(config)
            utils.visualize_iter(images, img_dir, step + 1, config)
    print("Training complete.")
    networks = model.get_networks()
    utils.save_networks(networks, model_dir)

    ## testing
    root = "./mvg_test"
    # NOTE: pickle.load executes arbitrary code from the file; only point
    # this at a trusted data.pkl.
    with open(os.path.join(root, "data.pkl"), "rb") as handle:
        fixed_z = pickle.load(handle)
    fixed_z = utils.to_var(fixed_z)
    fixed_gz = model.g(fixed_z).view(*fixed_z.size())
    utils.visualize_single(fixed_gz, os.path.join(img_dir, 'test.png'), config)
Exemple #5
0
def main():
    """Parse options, optionally compute a discrete-OT benchmark, train and save.

    Output goes under ``./<solver>_<data>/trial_<trial>/`` (or
    ``./w2_gen<gen>_<data>/trial_<trial>/`` for the ``w2`` solver). Unless
    ``config.no_benchmark`` is set, an assignment between the model's fixed
    samples is solved up front and the L2 distance to it is added to the
    statistics reported every 10 iterations.

    Raises:
        ValueError: if ``config.solver`` is not one of ``w1``, ``w2``,
            ``bary_ot``.
    """
    config = Options().parse()
    utils.print_opts(config)

    ## set up folders
    if config.solver != 'w2':
        exp_dir = './{0}_{1}/trial_{2}/'.format(config.solver, config.data, config.trial)
    else:
        exp_dir = './{0}_gen{2}_{1}/trial_{3}/'.format(config.solver, config.data, config.gen, config.trial)
    model_dir = os.path.join(exp_dir, 'models')
    img_dir = os.path.join(exp_dir, 'images')
    for folder in (exp_dir, model_dir, img_dir):
        os.makedirs(folder, exist_ok=True)

    if config.use_tbx:
        # remove every old tensorboardX log, not just the first match
        for old_log in glob.glob(os.path.join(exp_dir, 'events.out.tfevents.*')):
            os.remove(old_log)
        tbx_writer = SummaryWriter(exp_dir)
    else:
        tbx_writer = None

    ## initialize data loaders & model
    r_loader, z_loader = get_loader(config)
    if config.solver == 'w1':
        model = W1(config, r_loader, z_loader)
    elif config.solver == 'w2':
        model = W2(config, r_loader, z_loader)
    elif config.solver == 'bary_ot':
        model = BaryOT(config, r_loader, z_loader)
    else:
        # Fail here with a clear message instead of an unbound `model` below.
        raise ValueError("unknown solver: %r" % (config.solver,))
    cudnn.benchmark = True
    networks = model.get_networks(config)
    utils.print_networks(networks)

    # Plot limits shared by the data plot and the per-iteration visuals.
    data_range = (-12, 12) if config.data == '8gaussians' else (-6, 6)

    fixed_r, fixed_z = model.get_fixed_data()
    utils.visualize_single(utils.to_data(fixed_z),
                           utils.to_data(fixed_r),
                           None,
                           os.path.join(img_dir, 'data.png'),
                           data_range=data_range)
    if not config.no_benchmark:
        print('computing discrete-OT benchmark...')
        start_time = time.time()
        cost = model.get_cost()
        discrete_tz = utils.solve_assignment(fixed_z, fixed_r, cost,
                                             fixed_r.size(0))
        print('Done in %.4f seconds.' % (time.time() - start_time))
        utils.visualize_single(utils.to_data(fixed_z), utils.to_data(fixed_r),
                               utils.to_data(discrete_tz),
                               os.path.join(img_dir, 'assignment.png'))

    ## training
    ## stage 1 (dual stage) of bary_ot
    is_bary_ot = config.solver == 'bary_ot'
    start_time = time.time()
    if is_bary_ot:
        print("Starting: dual stage for %d iters." % config.dual_iters)
        for step in range(config.dual_iters):
            model.train_diter_only(config)
            if ((step + 1) % 10) == 0:
                stats = model.get_stats(config)
                end_time = time.time()
                # minutes elapsed since the previous report
                stats['disp_time'] = (end_time - start_time) / 60.
                start_time = end_time
                utils.print_out(stats, step + 1, config.dual_iters, tbx_writer)
        print("dual stage complete.")

    ## main training loop of w1 / w2 or stage 2 (map stage) of bary-ot
    map_iters = config.map_iters if is_bary_ot else config.train_iters
    if is_bary_ot:
        print("Starting: map stage for %d iters." % map_iters)
    else:
        print("Starting training...")
    for step in range(map_iters):
        model.train_iter(config)
        if ((step + 1) % 10) == 0:
            stats = model.get_stats(config)
            end_time = time.time()
            stats['disp_time'] = (end_time - start_time) / 60.
            start_time = end_time
            if not config.no_benchmark:
                # discrete_tz exists only when the benchmark ran above.
                if config.gen:
                    stats['l2_dist/discrete_T_x--G_x'] = losses.calc_l2(
                        fixed_z, model.g(fixed_z), discrete_tz).data.item()
                else:
                    stats['l2_dist/discrete_T_x--T_x'] = losses.calc_l2(
                        fixed_z, model.get_tx(fixed_z, reverse=True),
                        discrete_tz).data.item()
            utils.print_out(stats, step + 1, map_iters, tbx_writer)
        if ((step + 1) % 10000) == 0 or step == 0:
            images = model.get_visuals(config)
            utils.visualize_iter(images,
                                 img_dir,
                                 step + 1,
                                 config,
                                 data_range=data_range)
    print("Training complete.")
    networks = model.get_networks(config)
    utils.save_networks(networks, model_dir)
Exemple #6
0
def main():
    """Train, or only evaluate, a classifier configured by the global ``options``.

    ``options`` is a module-level dict; it is updated in place with
    ``feat_dim``, ``use_gpu`` and ``num_classes``. With ``options['eval']``
    set, stored networks are loaded and tested once, then the function
    returns. Otherwise the network is trained for ``options['max_epoch']``
    epochs, tested every ``options['eval_freq']`` epochs and after the last
    one, and saved after each test. When ``options['cs']`` is set, a GAN is
    trained alongside via ``train_cs`` and a grid of generated samples is
    written next to the checkpoints after each test.

    Returns:
        None.
    """
    torch.manual_seed(options['seed'])
    os.environ['CUDA_VISIBLE_DEVICES'] = options['gpu']
    use_gpu = torch.cuda.is_available()
    if options['use_cpu']:
        use_gpu = False

    feat_dim = 2 if 'cnn' in options['model'] else 512

    options.update(
        {
            'feat_dim': feat_dim,
            'use_gpu': use_gpu
        }
    )

    if use_gpu:
        print("Currently using GPU: {}".format(options['gpu']))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(options['seed'])
    else:
        print("Currently using CPU")

    dataset = datasets.create(options['dataset'], **options)
    out_dataset = datasets.create(options['out_dataset'], **options)

    trainloader, testloader = dataset.trainloader, dataset.testloader
    outloader = out_dataset.testloader

    options.update(
        {
            'num_classes': dataset.num_classes
        }
    )

    print("Creating model: {}".format(options['model']))
    if 'cnn' in options['model']:
        net = ConvNet(num_classes=dataset.num_classes)
    elif options['cs']:
        net = resnet34ABN(num_classes=dataset.num_classes, num_bns=2)
    else:
        net = ResNet34(dataset.num_classes)

    if options['cs']:
        print("Creating GAN")
        nz = options['nz']
        netG = gan.Generator32(1, nz, 64, 3) # ngpu, nz, ngf, nc
        netD = gan.Discriminator32(1, 3, 64) # ngpu, nc, ndf
        fixed_noise = torch.FloatTensor(64, nz, 1, 1).normal_(0, 1)
        criterionD = nn.BCELoss()

    Loss = importlib.import_module('loss.'+options['loss'])
    criterion = getattr(Loss, options['loss'])(**options)

    if use_gpu:
        # One logical device id per entry of the comma-separated GPU list.
        device_ids = list(range(len(options['gpu'].split(','))))
        net = nn.DataParallel(net, device_ids=device_ids).cuda()
        criterion = criterion.cuda()
        if options['cs']:
            netG = nn.DataParallel(netG, device_ids=device_ids).cuda()
            netD = nn.DataParallel(netD, device_ids=device_ids).cuda()
            # Tensor.cuda() returns a copy; the result has to be rebound.
            fixed_noise = fixed_noise.cuda()

    model_path = os.path.join(options['outf'], 'models', options['dataset'])
    file_name = '{}_{}_{}_{}_{}'.format(options['model'], options['dataset'], options['loss'], str(options['weight_pl']), str(options['cs']))
    if options['eval']:
        net, criterion = load_networks(net, model_path, file_name, criterion=criterion)
        results = test(net, criterion, testloader, outloader, epoch=0, **options)
        print("Acc (%): {:.3f}\t AUROC (%): {:.3f}\t OSCR (%): {:.3f}\t".format(results['ACC'], results['AUROC'], results['OSCR']))
        return

    params_list = [{'params': net.parameters()},
                   {'params': criterion.parameters()}]
    optimizer = torch.optim.Adam(params_list, lr=options['lr'])
    if options['cs']:
        optimizerD = torch.optim.Adam(netD.parameters(), lr=options['gan_lr'], betas=(0.5, 0.999))
        optimizerG = torch.optim.Adam(netG.parameters(), lr=options['gan_lr'], betas=(0.5, 0.999))

    use_scheduler = options['stepsize'] > 0
    if use_scheduler:
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 90, 120])

    start_time = time.time()

    for epoch in range(options['max_epoch']):
        print("==> Epoch {}/{}".format(epoch+1, options['max_epoch']))

        if options['cs']:
            train_cs(net, netD, netG, criterion, criterionD,
                     optimizer, optimizerD, optimizerG,
                     trainloader, epoch=epoch, **options)

        train(net, criterion, optimizer, trainloader, epoch=epoch, **options)

        is_eval_epoch = options['eval_freq'] > 0 and (epoch+1) % options['eval_freq'] == 0
        is_last_epoch = (epoch+1) == options['max_epoch']
        if is_eval_epoch or is_last_epoch:
            print("==> Test")
            results = test(net, criterion, testloader, outloader, epoch=epoch, **options)
            print("Acc (%): {:.3f}\t AUROC (%): {:.3f}\t OSCR (%): {:.3f}\t".format(results['ACC'], results['AUROC'], results['OSCR']))

            save_networks(net, model_path, file_name, criterion=criterion)
            if options['cs']:
                save_GAN(netG, netD, model_path, file_name)
                # Samples are only written to disk, so no autograd graph
                # is needed.
                with torch.no_grad():
                    fake = netG(fixed_noise)
                GAN_path = os.path.join(model_path, 'samples')
                mkdir_if_missing(GAN_path)
                vutils.save_image(fake.data, '%s/gan_samples_epoch_%03d.png'%(GAN_path, epoch), normalize=True)

        if use_scheduler:
            scheduler.step()

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))