Example #1
File: train.py  Project: lwshen/DCGAN
    dropout_value = config.dropout_value

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    # Training loader; batches yield (NK_img, OK_img) image pairs
    train_loader = DataLoader(dataset=dataloader.xDataSet(),
                              batch_size=config.batch_size,
                              num_workers=4,
                              drop_last=True)

    # Generator and discriminator ("generater" is the project's own class
    # name; lrelu_alpha is defined earlier in the file, above this excerpt)
    gen = model.generater(dropout_value=dropout_value).to(device)
    dis = model.discriminator(alpha=lrelu_alpha,
                              dropout_value=dropout_value).to(device)
    g_optimizer = torch.optim.Adam(gen.parameters(), lr=config.lr)
    d_optimizer = torch.optim.Adam(dis.parameters(), lr=config.lr)
    # Adversarial (BCE), structural-similarity, and pixel-wise L1 losses
    loss = nn.BCELoss()
    sloss = ssim_loss.SSIM()
    l1loss = nn.L1Loss()

    # print(gen)
    # print(dis)

    g_loss = []
    d_loss = []

    if not os.path.exists('./model'):
        os.mkdir('./model')

    for epoch in range(config.epoch):
        gen_data = 0
        # nk_img = 0
        for batch_index, (NK_img, OK_img) in enumerate(train_loader):
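            # --- The listing is truncated here. What follows is a minimal
            # --- sketch of a typical DCGAN-style step consistent with the
            # --- objects defined above (hypothetical, not the project's
            # --- actual loop body).
            NK_img, OK_img = NK_img.to(device), OK_img.to(device)

            # Discriminator step: real batch vs. generated batch
            # (assumes dis ends in a sigmoid, as BCELoss requires)
            d_optimizer.zero_grad()
            fake_img = gen(NK_img)
            real_out = dis(OK_img)
            fake_out = dis(fake_img.detach())
            d_err = loss(real_out, torch.ones_like(real_out)) + \
                    loss(fake_out, torch.zeros_like(fake_out))
            d_err.backward()
            d_optimizer.step()

            # Generator step: adversarial + SSIM + L1 reconstruction terms
            g_optimizer.zero_grad()
            fake_out = dis(fake_img)
            g_err = loss(fake_out, torch.ones_like(fake_out)) + \
                    (1 - sloss(fake_img, OK_img)) + l1loss(fake_img, OK_img)
            g_err.backward()
            g_optimizer.step()
            g_loss.append(g_err.item())
            d_loss.append(d_err.item())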
Example #2
def main(run):
    print(opt)
    num_gpus = torch.cuda.device_count()
    print("available gpus are", num_gpus)
    num_cpus = multiprocessing.cpu_count()
    print("available cpus are", num_cpus)

    save_dir = './result/{}/{}{}x{:.0e}/FULLTRAIN_SUPER_R{}_RAN{}/'.format(
        opt.train_dir, opt.model, opt.upscale_factor, opt.lr,
        opt.inter_frequency, opt.ranlevel)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    res_dir = save_dir + str(run) + '/'
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    cuda = opt.cuda
    if cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    torch.manual_seed(opt.seed)
    torch.backends.cudnn.enabled = True
    # Note: benchmark=True lets cuDNN autotune kernels, which can break
    # exact reproducibility even with deterministic=True; disable it if
    # strict determinism is required.
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    np.random.seed(opt.seed)
    if cuda:
        torch.cuda.manual_seed(opt.seed)

    print('===> Loading datasets')

    train_dir = "./data/{}/".format(opt.train_dir)
    train_set = TrainDatasetFromFolder(
        train_dir,
        is_gray=True,
        random_scale=True,
        crop_size=opt.upscale_factor * opt.patchsize,
        rotate=True,
        fliplr=True,
        fliptb=True,
        scale_factor=opt.upscale_factor,
        bic_inp=(opt.model == 'vdsr'))
    training_data_loader = DataLoader(dataset=train_set,
                                      num_workers=opt.threads,
                                      batch_size=opt.batchSize,
                                      shuffle=True)

    val_set = TestDatasetFromFolder(
        train_dir,
        is_gray=True,
        scale_factor=opt.upscale_factor,
        bic_inp=(opt.model == 'vdsr'))
    validating_data_loader = DataLoader(dataset=val_set,
                                        num_workers=opt.threads,
                                        batch_size=opt.testBatchSize,
                                        shuffle=False)

    if opt.model == 'vdsr':
        model_mse = vdsr.Net(num_channels=1, base_filter=64, num_residuals=18)
        model_l1 = vdsr.Net(num_channels=1, base_filter=64, num_residuals=18)
        model_ssim = vdsr.Net(num_channels=1, base_filter=64, num_residuals=18)
    else:
        model_mse = edsr.Net(num_channels=1,
                             base_filter=64,
                             num_residuals=18,
                             scale=opt.upscale_factor)
        model_l1 = edsr.Net(num_channels=1,
                            base_filter=64,
                            num_residuals=18,
                            scale=opt.upscale_factor)
        model_ssim = edsr.Net(num_channels=1,
                              base_filter=64,
                              num_residuals=18,
                              scale=opt.upscale_factor)

    if opt.resume:
        model_mse = load_model(res_dir + str(opt.alpha) +
                               "model_{}_{}_epoch_{}.pth".format(
                                   'mse', opt.upscale_factor, opt.nEpochs))
        model_l1 = load_model(res_dir + str(opt.alpha) +
                              "model_{}_{}_epoch_{}.pth".format(
                                  'l1', opt.upscale_factor, opt.nEpochs))
        model_ssim = load_model(res_dir + str(opt.alpha) +
                                "model_{}_{}_epoch_{}.pth".format(
                                    'ssim', opt.upscale_factor, opt.nEpochs))

    print('===> Building criterions')
    criterion_mse = nn.MSELoss()
    criterion_l1 = nn.L1Loss()
    criterion_ssim = ssim_loss.SSIM(size_average=False)

    print('===> Building optimizers')
    optimizer_mse = optim.Adam(model_mse.parameters(), lr=opt.lr)
    optimizer_l1 = optim.Adam(model_l1.parameters(), lr=opt.lr)
    optimizer_ssim = optim.Adam(model_ssim.parameters(), lr=opt.lr)

    m1 = Train_op(model_mse, optimizer_mse, criterion_mse, 'mse', opt.lr,
                  opt.upscale_factor, opt.cuda)
    m2 = Train_op(model_l1, optimizer_l1, criterion_l1, 'l1', opt.lr,
                  opt.upscale_factor, opt.cuda)
    m3 = Train_op(model_ssim, optimizer_ssim, criterion_ssim, 'ssim', opt.lr,
                  opt.upscale_factor, opt.cuda)

    models = [m1, m2, m3]

    x_label = []  # epochs at which the loss is evaluated
    y_label = []  # corresponding per-model losses

    print('===> start training')

    for epoch in range(0, opt.nEpochs + 1, opt.inter_frequency):
        tick_time = time.time()
        print('running epoch {}'.format(epoch))

        # Step-decay schedule: scale the base LR by decay_rate every
        # opt.step epochs, then push the new value to all three models
        lr = opt.lr * (opt.decay_rate**(epoch // opt.step))
        print('epoch {}, learning rate is {}'.format(epoch, lr))
        for m in models:
            m.update_lr(lr)
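        # (Illustrative arithmetic with hypothetical values: lr=1e-4,
        # decay_rate=0.5, step=100 gives 1e-4 for epochs 0-99, 5e-5 for
        # epochs 100-199, 2.5e-5 for 200-299, and so on.)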

        update_loss = interchange_im(models, training_data_loader,
                                     validating_data_loader, opt.alpha)

        print('evaluated loss is', update_loss)
        x_label.append(epoch)
        y_label.append(update_loss)

        print('this epoch cost {} seconds.'.format(time.time() - tick_time))

        # Checkpoint roughly every 10% of training (assumes nEpochs >= 10;
        # otherwise nEpochs // 10 is 0 and the modulo would raise)
        if epoch % (opt.nEpochs // 10) == 0:
            for m in models:
                m.checkpoint(epoch, res_dir, prefix=str(opt.alpha))

    for m in models:
        m.checkpoint(epoch, res_dir, prefix=str(opt.alpha))

    # save obtained losses
    x_label = np.asarray(x_label)
    y_label = np.asarray(y_label).transpose()
    output = np.insert(y_label, 0, x_label, axis=0)

    if len(models) > 1:
        np.savetxt(res_dir + str(opt.alpha) + opt.save_file,
                   output,
                   fmt='%3.5f')
    else:
        np.savetxt(res_dir + 'loss_' + models[0].name + '.txt',
                   output,
                   fmt='%3.5f')
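
Example #2 assumes a module-level argparse namespace `opt`; a minimal driver
sketch (the run count of 3 is a hypothetical choice, not from the project):

if __name__ == '__main__':
    # Repeat the full procedure for several independent runs; `run` only
    # selects the result subdirectory.
    for run in range(3):
        main(run)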
Example #3
def train(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------

    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark

    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    ssimLoss = ssim_loss.SSIM().cuda()

    # Initialize Generator
    generator = utils.create_generator(opt)

    # To device (optionally wrapped for multi-GPU data parallelism)
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
    generator = generator.cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=opt.lr_g,
                                   betas=(opt.b1, opt.b2),
                                   weight_decay=opt.weight_decay)

    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Decay the initial LR by "lr_decrease_factor" every
        # "lr_decrease_epoch" epochs (or every "lr_decrease_iter" iterations)
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        elif opt.lr_decrease_mode == 'iter':
            lr = opt.lr_g * (opt.lr_decrease_factor
                             **(iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # Save the model every "save_by_epoch" epochs
    def save_model(net, epoch, opt):
        """Save the model at multiples of "save_by_epoch"."""
        if epoch % opt.save_by_epoch == 0:
            # Unwrap DataParallel so the checkpoint loads without the
            # multi-GPU wrapper (assumes ./model/ already exists)
            to_save = net.module if opt.multi_gpu else net
            torch.save(
                to_save, './model/epoch%d_batchsize%d.pth' %
                (epoch, opt.batch_size))
            print('The trained model is successfully saved at epoch %d' %
                  epoch)

    # ----------------------------------------
    #             Network dataset
    # ----------------------------------------

    # Define the dataset
    trainset = dataset.UDCDataset(opt)
    #testset = dataset.UDCValidDataset(opt)
    print('The overall number of images:', len(trainset))

    # Define the dataloader
    dataloader = DataLoader(trainset,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers,
                            pin_memory=True)
    #test_dataloader = DataLoader(testset, batch_size = 1, pin_memory = True)
    # ----------------------------------------
    #                 Training
    # ----------------------------------------

    # Count start time
    prev_time = time.time()

    # For loop training
    for epoch in range(opt.epochs):
        avg_l1_loss = 0
        avg_ssim_loss = 0
        # Placeholders: the color and low-frequency terms are never
        # accumulated in this version and are logged as 0 below
        avg_cs_ColorLoss = 0
        avg_l1_loss_lf = 0
        avg_ssim_loss_lf = 0
        generator.train()
        for i, (true_input, true_target) in enumerate(dataloader):

            # To device
            true_input = true_input.cuda()
            true_target = true_target.cuda()

            # Train Generator
            optimizer_G.zero_grad()
            fake_target = generator(true_input)

            # L1 Loss
            Pixellevel_L1_Loss = criterion_L1(fake_target, true_target)

            # Map outputs from [-1, 1] back to [0, 1] before computing SSIM
            fake_target = fake_target * 0.5 + 0.5
            true_target = true_target * 0.5 + 0.5
            ssim_PixelLoss = 1 - ssimLoss(fake_target, true_target)

            # Overall Loss and optimize
            loss = Pixellevel_L1_Loss + ssim_PixelLoss

            # Detach before accumulating so the autograd graph of every
            # batch is not kept alive across the epoch
            avg_l1_loss += Pixellevel_L1_Loss.detach()
            avg_ssim_loss += ssim_PixelLoss.detach()
            loss.backward()
            optimizer_G.step()

            # Determine approximate time left
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds=iters_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            print(
                "\r[Epoch %d/%d] [Batch %d/%d] [Pixellevel L1 Loss: %.4f] [Pixellevel L1 Loss LowFreq: %.4f] [ssim Loss: %.4f] [ssim Loss LowFreq: %.4f] Time_left: %s"
                % ((epoch + 1), opt.epochs, i, len(dataloader),
                   Pixellevel_L1_Loss.item(), 0, ssim_PixelLoss.item(), 0,
                   time_left))

        # Save model at certain epochs
        save_model(generator, (epoch + 1), opt)
        #valid(generator,test_dataloader,(epoch + 1),opt)
        # Learning rate decrease at certain epochs
        adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
        # Average the accumulated losses over the batches of this epoch
        avg_l1_loss = avg_l1_loss / (i + 1)
        avg_ssim_loss = avg_ssim_loss / (i + 1)

        # Append per-epoch averages to a running log
        with open("log.txt", "a") as f:
            f.write('epoch: ' + str(epoch) + ' avg l1 = ' +
                    str(avg_l1_loss.item()) + ' avg l1 LowFreq = ' + str(0) +
                    ' avg ssim = ' + str(avg_ssim_loss.item()) +
                    ' avg ssim LowFreq = ' + str(0) + '\n')
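
All three snippets rely on an `opt`/`config` namespace built elsewhere. A
minimal sketch of the kind of argparse setup example #3 presumes; the flag
names match the attributes used above, but the defaults are hypothetical and
the real project defines more flags (e.g. for the dataset paths):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--lr_g', type=float, default=1e-4)
parser.add_argument('--b1', type=float, default=0.5)
parser.add_argument('--b2', type=float, default=0.999)
parser.add_argument('--weight_decay', type=float, default=0.0)
parser.add_argument('--multi_gpu', action='store_true')
parser.add_argument('--cudnn_benchmark', action='store_true')
parser.add_argument('--lr_decrease_mode', type=str, default='epoch')
parser.add_argument('--lr_decrease_factor', type=float, default=0.5)
parser.add_argument('--lr_decrease_epoch', type=int, default=30)
parser.add_argument('--lr_decrease_iter', type=int, default=100000)
parser.add_argument('--save_by_epoch', type=int, default=10)
opt = parser.parse_args()

train(opt)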