Example no. 1
def train(it_id):
    while True:
        for idx, (image_a,
                  image_b) in enumerate(zip(train_loader_a, train_loader_b)):
            # print(idx)

            # obtain input image pairs
            image_a = image_a.cuda().detach() if torch.cuda.is_available() else image_a.detach()
            image_b = image_b.cuda().detach() if torch.cuda.is_available() else image_b.detach()

            # Main training code
            model.dis_update(image_a, image_b, config)
            model.gen_update(image_a, image_b, config)

            # Updating lr
            model.dis_scheduler.step()
            model.gen_scheduler.step()

            # Dump training stats in log file
            if (it_id + 1) % config.log_iter == 0:
                write_loss(it_id, model, train_writer)

            # Save network weights
            if (it_id + 1) % config.snapshot_save_iter == 0:
                model.save(checkpoint_directory, it_id)

            it_id += 1
            if it_id + 1 >= max_iter:
                sys.exit('Finish training')
Example no. 2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/OT3+++R.yaml', help='Path to the config file.')
    parser.add_argument('--output_path', type=str, default='.', help="outputs path")
    parser.add_argument("--resume",default='True', action="store_true") #change to True is you need to retrain from pre-train model
    opts = parser.parse_args()

    cudnn.benchmark = True

    # Load experiment setting
    config = get_config(opts.config)

    # dataset set up
    dataset = My3DDataset(opts=config)
    train_loader = DataLoader(dataset=dataset, batch_size=config['batch_size'], shuffle=True, num_workers=config['nThreads'])


    config['vgg_model_path'] = opts.output_path

    trainer = Models(config)
    trainer.cuda()

    # Setup logger and output folders
    model_name = os.path.splitext(os.path.basename(opts.config))[0]
    output_directory = os.path.join(opts.output_path + "/outputs", model_name)
    train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/outputs/logs", model_name))
    checkpoint_directory, image_directory = prepare_sub_folder(output_directory)
    shutil.copy(opts.config, os.path.join(output_directory, 'config.yaml')) # copy config file to output folder

    # Start training
    iterations = trainer.resume(checkpoint_directory, hyperparameters=config, need_opt=False) if opts.resume else 0
    max_iter = int(config['n_ep'] * len(dataset) / config['batch_size']) + 1

    while True:
        for it, out_data in enumerate(train_loader):
            for j in range(len(out_data)):
                out_data[j] = out_data[j].cuda().detach()
            # this unpacking assumes the 'dynamic_human' batch layout; any other
            # models_name would leave these variables undefined below
            if config['models_name'] == 'dynamic_human':
                Xa_out, Xb_out, Yb_out, Xb_prev_out, Xb_next_out, Xa_mask, Yb_mask, rand_y_out, rand_y_mask = out_data
            trainer.update_learning_rate()
            with Timer("Elapsed time in update: %f"):
                # Main training code
                trainer.dynamic_gen_update(Xa_out, Xb_out, Yb_out, Xb_prev_out, Xb_next_out,
                                           Xa_mask, Yb_mask, rand_y_out, rand_y_mask, config)
                #torch.cuda.synchronize()
            # Dump training stats in log file
            if (iterations + 1) % config['log_iter'] == 0:
                print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
                write_loss(iterations, trainer, train_writer)

            if iterations % config['image_display_iter'] == 0:
                write_image2display(iterations, trainer, train_writer)

            # Save network weights
            if (iterations + 1) % config['snapshot_save_iter'] == 0:
                trainer.save(checkpoint_directory, iterations)
            iterations += 1
            if iterations >= max_iter:
                sys.exit('Finish training')
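
The Timer context manager wrapped around the update step in several of these examples comes from each project's utils module and is not shown here. A minimal sketch of the behaviour the call sites rely on, assuming it only measures and prints wall-clock time through the given format string (e.g. "Elapsed time in update: %f"):

import time

class Timer:
    """Context manager that prints elapsed wall-clock time as msg % seconds."""

    def __init__(self, msg):
        self.msg = msg

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        print(self.msg % (time.time() - self.start))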
Example no. 3
    def train(self, first_epoch=0):
        global_it = first_epoch * len(self.data_loader)
        num_epochs = self.config['modalities_encoder']['epochs']
        for epoch_counter in range(first_epoch, num_epochs):
            with Timer("Elapsed time for epoch: %f"):
                for it, ((xis, xjs), _, _) in enumerate(self.data_loader):
                    self.opt.zero_grad()

                    xis = xis.to(self.config['device'])
                    xjs = xjs.to(self.config['device'])

                    self.loss_enc_contrastive = self.model.calc_nt_xent_loss(
                        xis, xjs)
                    self.loss_enc_contrastive.backward()
                    self.opt.step()

                    print(
                        "Epoch: {curr_epoch}/{total_epochs} | Iteration: {curr_iter}/{total_iter} | Loss: {curr_loss}"
                        .format(curr_epoch=epoch_counter + 1,
                                total_epochs=num_epochs,
                                curr_iter=str(global_it + 1).zfill(8),
                                total_iter=str(
                                    len(self.data_loader) *
                                    num_epochs).zfill(8),
                                curr_loss=self.loss_enc_contrastive))

                    # Logging loss
                    if global_it % self.config['logger']['log_loss'] == 0:
                        write_loss(global_it, self, self.logger)

                    global_it += 1

            if epoch_counter % self.config['logger']['checkpoint_modalities_encoder_every'] == 0:
                self.save(self.config['logger']['checkpoint_dir'], epoch_counter)

            if epoch_counter >= 10:
                self.scheduler.step()
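
Example no. 3 relies on model.calc_nt_xent_loss, the normalized-temperature cross-entropy (NT-Xent) contrastive loss from SimCLR. A minimal sketch of that loss for two batches of projected embeddings, assuming the model method reduces to this standard computation (the temperature value is a placeholder):

import torch
import torch.nn.functional as F

def nt_xent_loss(zi, zj, temperature=0.5):
    # zi, zj: (N, D) projections of two augmented views of the same N images
    z = F.normalize(torch.cat([zi, zj], dim=0), dim=1)  # (2N, D), unit rows
    sim = z @ z.t() / temperature                       # scaled cosine similarities
    sim.fill_diagonal_(float('-inf'))                   # exclude self-pairs
    n = zi.size(0)
    # positives: the i-th view matches the (i+N)-th view and vice versa
    targets = torch.cat([torch.arange(n) + n, torch.arange(n)]).to(sim.device)
    return F.cross_entropy(sim, targets)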
Example no. 4
        for it, (images_a,
                 images_b) in enumerate(zip(train_loader_a, train_loader_b)):
            trainer.update_learning_rate()
            images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()

            with Timer("Elapsed time in update: %f"):
                # Main training code
                trainer.dis_update(images_a, images_b, config)
                trainer.gen_update(images_a, images_b, config)
                torch.cuda.synchronize()

            # Dump training stats in log file
            if (iterations + 1) % config['log_iter'] == 0:
                print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
                write_loss(iterations, trainer, train_writer)

            # Write images
            if (iterations + 1) % config['image_save_iter'] == 0:
                with torch.no_grad():
                    test_image_outputs = trainer.sample(
                        test_display_images_a, test_display_images_b)
                    train_image_outputs = trainer.sample(
                        train_display_images_a, train_display_images_b)
                write_2images(test_image_outputs, display_size,
                              image_directory, 'test_%08d' % (iterations + 1))
                write_2images(train_image_outputs, display_size,
                              image_directory, 'train_%08d' % (iterations + 1))
                # HTML
                write_html(output_directory + "/index.html", iterations + 1,
                           config['image_save_iter'], 'images')
Example no. 5
def main(argv):
    (opts, args) = parser.parse_args(argv)
    cudnn.benchmark = True
    model_name = os.path.splitext(os.path.basename(opts.config))[0]

    # Load experiment setting
    config = get_config(opts.config)
    max_iter = config['max_iter']
    display_size = config['display_size']

    # Setup model and data loader
    trainer = MUNIT_Trainer(config)
    trainer.cuda()
    train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(
        config)
    # Variable(..., volatile=True) is the pre-0.4 PyTorch idiom for
    # inference-only tensors (replaced today by torch.no_grad())
    test_display_images_a = Variable(torch.stack(
        [test_loader_a.dataset[i] for i in range(display_size)]).cuda(),
                                     volatile=True)
    test_display_images_b = Variable(torch.stack(
        [test_loader_b.dataset[i] for i in range(display_size)]).cuda(),
                                     volatile=True)
    train_display_images_a = Variable(torch.stack(
        [train_loader_a.dataset[i] for i in range(display_size)]).cuda(),
                                      volatile=True)
    train_display_images_b = Variable(torch.stack(
        [train_loader_b.dataset[i] for i in range(display_size)]).cuda(),
                                      volatile=True)

    # Setup logger and output folders
    train_writer = tensorboard.SummaryWriter(os.path.join(
        opts.log, model_name))
    output_directory = os.path.join(opts.outputs, model_name)
    checkpoint_directory, image_directory = prepare_sub_folder(
        output_directory)
    shutil.copy(opts.config, os.path.join(
        output_directory, 'config.yaml'))  # copy config file to output folder

    # Start training
    iterations = trainer.resume(checkpoint_directory) if opts.resume else 0
    while True:
        for it, (images_a,
                 images_b) in enumerate(izip(train_loader_a, train_loader_b)):  # izip is the Python 2 itertools name; use zip on Python 3
            trainer.update_learning_rate()
            images_a, images_b = Variable(images_a.cuda()), Variable(
                images_b.cuda())

            # Main training code
            trainer.dis_update(images_a, images_b, config)
            trainer.gen_update(images_a, images_b, config)

            # Dump training stats in log file
            if (iterations + 1) % config['log_iter'] == 0:
                print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
                write_loss(iterations, trainer, train_writer)

            # Write images
            if (iterations + 1) % config['image_save_iter'] == 0:
                # Test set images
                image_outputs = trainer.sample(test_display_images_a,
                                               test_display_images_b)
                write_images(
                    image_outputs, display_size,
                    '%s/gen_test%08d.jpg' % (image_directory, iterations + 1))
                # Train set images
                image_outputs = trainer.sample(train_display_images_a,
                                               train_display_images_b)
                write_images(
                    image_outputs, display_size,
                    '%s/gen_train%08d.jpg' % (image_directory, iterations + 1))
                # HTML
                write_html(output_directory + "/index.html", iterations + 1,
                           config['image_save_iter'], 'images')
            if (iterations + 1) % config['image_save_iter'] == 0:
                image_outputs = trainer.sample(test_display_images_a,
                                               test_display_images_b)
                write_images(image_outputs, display_size,
                             '%s/gen.jpg' % image_directory)

            # Save network weights
            if (iterations + 1) % config['snapshot_save_iter'] == 0:
                trainer.save(checkpoint_directory, iterations)

            iterations += 1
            if iterations >= max_iter:
                return
Example no. 6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str,
                        default='configs/edges2handbags_folder.yaml',
                        help='Path to the config file.')
    parser.add_argument('--output_path',
                        type=str,
                        default='.',
                        help="outputs path")
    # resume option; action="store_true" expects a boolean default
    # (the original string defaults such as '730000'/'150000' were always truthy)
    parser.add_argument("--resume", default=True, action="store_true")
    parser.add_argument('--trainer',
                        type=str,
                        default='MUNIT',
                        help="MUNIT|UNIT")
    opts = parser.parse_args()

    cudnn.benchmark = True

    # Load experiment setting
    config = get_config(opts.config)
    max_iter = config['max_iter']
    display_size = config['display_size']
    config['vgg_model_path'] = opts.output_path

    # Setup model and data loader
    if opts.trainer == 'MUNIT':
        trainer = MUNIT_Trainer(config)
    elif opts.trainer == 'UNIT':
        trainer = UNIT_Trainer(config)
    else:
        sys.exit("Only support MUNIT|UNIT")
    trainer.cuda()
    train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(
        config)
    train_display_images_a = torch.stack(
        [train_loader_a.dataset[i] for i in range(display_size)]).cuda()
    train_display_images_b = torch.stack(
        [train_loader_b.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_a = torch.stack(
        [test_loader_a.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_b = torch.stack(
        [test_loader_b.dataset[i] for i in range(display_size)]).cuda()

    # Setup logger and output folders
    model_name = os.path.splitext(os.path.basename(opts.config))[0]
    train_writer = tensorboardX.SummaryWriter(
        os.path.join(opts.output_path + "/logs", model_name))
    output_directory = os.path.join(opts.output_path + "/outputs", model_name)
    checkpoint_directory, image_directory = prepare_sub_folder(
        output_directory)
    shutil.copy(opts.config, os.path.join(
        output_directory, 'config.yaml'))  # copy config file to output folder

    # Start training
    iterations = trainer.resume(checkpoint_directory,
                                hyperparameters=config) if opts.resume else 0
    while True:
        for it, (images_a,
                 images_b) in enumerate(zip(train_loader_a, train_loader_b)):
            trainer.update_learning_rate()
            images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()

            with Timer("Elapsed time in update: %f"):
                # Main training code
                trainer.dis_update(images_a, images_b, config)
                trainer.gen_update(images_a, images_b, config)
                torch.cuda.synchronize()

            # Dump training stats in log file
            if (iterations + 1) % config['log_iter'] == 0:
                print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
                write_loss(iterations, trainer, train_writer)

            # Write images
            if (iterations + 1) % config['image_save_iter'] == 0:
                with torch.no_grad():
                    test_image_outputs = trainer.sample(
                        test_display_images_a, test_display_images_b)
                    train_image_outputs = trainer.sample(
                        train_display_images_a, train_display_images_b)
                write_2images(test_image_outputs, display_size,
                              image_directory, 'test_%08d' % (iterations + 1))
                write_2images(train_image_outputs, display_size,
                              image_directory, 'train_%08d' % (iterations + 1))
                # HTML
                write_html(output_directory + "/index.html", iterations + 1,
                           config['image_save_iter'], 'images')

            if (iterations + 1) % config['image_display_iter'] == 0:
                with torch.no_grad():
                    image_outputs = trainer.sample(train_display_images_a,
                                                   train_display_images_b)
                write_2images(image_outputs, display_size, image_directory,
                              'train_current')

            # Save network weights
            if (iterations + 1) % config['snapshot_save_iter'] == 0:
                trainer.save(checkpoint_directory, iterations)

            iterations += 1
            if iterations >= max_iter:
                sys.exit('Finish training')
Example no. 7
    test = normalize_arr_of_imgs(test.cuda()).permute(0, 3, 1, 2)
    torch.backends.cudnn.benchmark = True
    # Start training
    for step in tqdm(range(initial_step, opts.max_iter + 1), initial=initial_step, total=opts.max_iter, ncols=64, mininterval=2):
        # Busy-wait until both batch queues are non-empty (the blocking
        # q.get() calls below would make this spin unnecessary).
        while q_art.empty() or q_content.empty():
            pass
        batch_art = normalize_arr_of_imgs(
            torch.tensor(q_art.get()['image'], requires_grad=False).cuda()
        ).permute(0, 3, 1, 2).requires_grad_()
        batch_content = normalize_arr_of_imgs(
            torch.tensor(q_content.get()['image'], requires_grad=False).cuda()
        ).permute(0, 3, 1, 2).requires_grad_()
        # Training update
        trainer.update_learning_rate()
        discr_success = trainer.update(batch_art, batch_content, opts, discr_success, alpha, discr_success >= win_rate)

        # Dump training stats in log file
        if step % 10 == 0:
            write_loss(step, trainer, train_writer)
        # Save network weights
        if (step+1) % opts.save_freq == 0:
            trainer.save(checkpoint_directory, step)
        if step % 50 == 0:
            print("Iteration: %08d/%08d, dloss = %.8s, gloss = %.8s, discr_success = %.5s" % (step, opts.max_iter, trainer.discr_loss.item(), trainer.gener_loss.item(), discr_success))
        # Write images
        if (step+1) % 100 == 0:
            del batch_art, batch_content
            torch.cuda.empty_cache()
            with torch.no_grad():
                samp = trainer.sample(test)
                image_outputs = [denormalize_arr_of_imgs(samp[0]), denormalize_arr_of_imgs(samp[1])]
            write_2images(image_outputs, opts.display_size, image_directory, 'test_%08d' % (step + 1))
            del samp, image_outputs
            torch.cuda.empty_cache()
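
Example no. 7 assumes producer processes keep the two queues q_art and q_content filled with dicts holding an 'image' batch. A hedged sketch of such a feeder; the function name, batch shape, and stand-in data are assumptions for illustration only:

import multiprocessing as mp
import numpy as np

def feeder(q, images, batch_size=4):
    # endlessly push random batches of HWC images into the queue
    rng = np.random.default_rng()
    while True:
        idx = rng.integers(0, len(images), size=batch_size)
        q.put({'image': images[idx]})

if __name__ == '__main__':
    images = np.zeros((100, 256, 256, 3), dtype=np.uint8)  # stand-in dataset
    q_art = mp.Queue(maxsize=8)
    mp.Process(target=feeder, args=(q_art, images), daemon=True).start()
    print(q_art.get()['image'].shape)  # (4, 256, 256, 3)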
Example no. 8
while True:
    for it, (images_a,
             images_b) in enumerate(zip(train_loader_a, train_loader_b)):
        trainer.update_learning_rate()
        images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()

        with Timer("Elapsed time in update: %f"):
            # Main training code
            trainer.dis_update(images_a, images_b, config)
            trainer.gen_update(images_a, images_b, config)
            torch.cuda.synchronize()

        # Dump training stats in log file
        if (iterations + 1) % config["log_iter"] == 0:
            print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
            write_loss(iterations, trainer, train_writer, comet_exp=comet_exp)

        # Write images
        if (iterations + 1) % config["image_save_iter"] == 0:
            with torch.no_grad():
                test_image_outputs = trainer.sample(test_display_images_a,
                                                    test_display_images_b)
                train_image_outputs = trainer.sample(train_display_images_a,
                                                     train_display_images_b)
            write_2images(
                test_image_outputs,
                display_size,
                image_directory,
                "test_%08d" % (iterations + 1),
                comet_exp=comet_exp,
            )
Example no. 9
                                   labels_a,
                                   labels_b,
                                   config,
                                   iterations,
                                   num_gpu=1)

            torch.cuda.synchronize()

        # Dump training stats in log file
        # print the training log
        if (iterations + 1) % config['log_iter'] == 0:
            print("\033[1m Epoch: %02d Iteration: %08d/%08d \033[0m" %
                  (nepoch, iterations + 1, max_iter),
                  end=" ")
            if num_gpu == 1:
                write_loss(iterations, trainer, train_writer)
            else:
                write_loss(iterations, trainer.module, train_writer)

        # Write images
        # when the iteration count is reached, save the images
        if (iterations + 1) % config['image_save_iter'] == 0:
            with torch.no_grad():
                if num_gpu > 1:
                    test_image_outputs = trainer.module.sample(
                        test_display_images_a, test_display_images_b)
                else:
                    test_image_outputs = trainer.sample(
                        test_display_images_a, test_display_images_b)
            write_2images(test_image_outputs, display_size, image_directory,
                          'test_%08d' % (iterations + 1))
Example no. 10
def main():
    cudnn.benchmark = True
    # Load experiment setting
    config = get_config(opts.config)
    max_iter = config['max_iter']
    display_size = config['display_size']
    config['vgg_model_path'] = opts.output_path

    # Setup model and data loader
    trainer = UNIT_Trainer(config)
    if torch.cuda.is_available():
        trainer.cuda(config['gpuID'])
    train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(
        config)

    # Setup logger and output folders
    model_name = os.path.splitext(os.path.basename(opts.config))[0]
    writer = SummaryWriter(os.path.join(opts.output_path + "/logs",
                                        model_name))
    output_directory = os.path.join(opts.output_path + "/outputs", model_name)
    checkpoint_directory, image_directory = prepare_sub_folder(
        output_directory)
    shutil.copy(opts.config, os.path.join(
        output_directory, 'config.yaml'))  # copy config file to output folder

    print('start training !!')
    # Start training
    iterations = trainer.resume(checkpoint_directory,
                                hyperparameters=config) if opts.resume else 0

    TraindataA = data_prefetcher(train_loader_a)
    TraindataB = data_prefetcher(train_loader_b)
    testdataA = data_prefetcher(test_loader_a)
    testdataB = data_prefetcher(test_loader_b)

    while True:
        dataA = TraindataA.next()
        dataB = TraindataB.next()
        if dataA is None or dataB is None:
            TraindataA = data_prefetcher(train_loader_a)
            TraindataB = data_prefetcher(train_loader_b)
            dataA = TraindataA.next()
            dataB = TraindataB.next()
        with Timer("Elapsed time in update: %f"):
            # Main training code
            for _ in range(3):
                trainer.content_update(dataA, dataB, config)
            trainer.dis_update(dataA, dataB, config)
            trainer.gen_update(dataA, dataB, config)
            # torch.cuda.synchronize()
        trainer.update_learning_rate()
        # Dump training stats in log file
        if (iterations + 1) % config['log_iter'] == 0:
            print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
            write_loss(iterations, trainer, writer)
        if (iterations + 1) % config['image_save_iter'] == 0:
            testa = testdataA.next()
            testb = testdataB.next()
            # restart the test prefetchers when exhausted or on a short batch
            if testa is None or testb is None or testa.size(0) != display_size or testb.size(0) != display_size:
                testdataA = data_prefetcher(test_loader_a)
                testdataB = data_prefetcher(test_loader_b)
                testa = testdataA.next()
                testb = testdataB.next()
            with torch.no_grad():
                test_image_outputs = trainer.sample(testa, testb)
                train_image_outputs = trainer.sample(dataA, dataB)
            if test_image_outputs is not None and train_image_outputs is not None:
                write_2images(test_image_outputs, display_size,
                              image_directory, 'test_%08d' % (iterations + 1))
                write_2images(train_image_outputs, display_size,
                              image_directory, 'train_%08d' % (iterations + 1))
                # HTML
                write_html(output_directory + "/index.html", iterations + 1,
                           config['image_save_iter'], 'images')

        if (iterations + 1) % config['image_display_iter'] == 0:
            with torch.no_grad():
                image_outputs = trainer.sample(dataA, dataB)
            if image_outputs is not None:
                write_2images(image_outputs, display_size, image_directory,
                              'train_current')

        # Save network weights
        if (iterations + 1) % config['snapshot_save_iter'] == 0:
            trainer.save(checkpoint_directory, iterations)

        iterations += 1
        if iterations >= max_iter:
            writer.close()
            sys.exit('Finish training')
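
The data_prefetcher helper used above is defined elsewhere in that project; the call sites only rely on next() returning the next batch already on the GPU, or None once the loader is exhausted. A minimal sketch under that assumption (real prefetchers typically also overlap the host-to-device copy on a side CUDA stream, as here; a CUDA device is required):

import torch

class data_prefetcher:
    def __init__(self, loader):
        self.it = iter(loader)
        self.stream = torch.cuda.Stream()  # side stream for async copies
        self.batch = None
        self._preload()

    def _preload(self):
        try:
            batch = next(self.it)
        except StopIteration:
            self.batch = None  # signals exhaustion to the caller
            return
        with torch.cuda.stream(self.stream):
            self.batch = batch.cuda(non_blocking=True)

    def next(self):
        # make sure the pending copy has finished before handing out the batch
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is not None:
            self._preload()  # immediately start fetching the next batch
        return batch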
Example no. 11
            if not os.path.exists(img_iter_dir):
                os.makedirs(img_iter_dir)
            tester.test(config['input_option'], trainer.model, img_iter_dir)
            trainer.train()

        #with utils.Timer("Elapsed time in update: %f"):
        trainer.encoder_update(images, gt_2d, gt_3d, mask, valid_3d)
        trainer.update_lr()
        torch.cuda.synchronize()  # synchronize GPU and CPU so the time measurement is accurate

        # Dump training stats in log file
        if (iterations + 1) % config['print_loss_iter'] == 0:
            print "Iteration: %08d/%08d, " % (iterations + 1, max_iter),
            trainer.print_losses()
            utils.write_loss(iterations, trainer, train_writer)

        if (iterations + 1) == 100 or (iterations + 1) % config['test_iter'] == 0:
            trainer.eval()
            new_trainloader, new_testLoader = utils.get_data_loader(
                config, isPretrain=False)
            train_num, train_loss = sample(trainer, new_trainloader,
                                           config['batch_size'],
                                           config['test_num'])
            test_num, test_loss = sample(trainer, new_testLoader,
                                         config['batch_size'],
                                         config['test_num'])
            trainer.train()
            # print('test on %d images in trainset, %d images in testset' % (train_num, test_num))
            loss_log[0].append(iterations + 1)
Example no. 12
def main():
    from utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer
    import argparse
    from torch.autograd import Variable
    from trainer import MUNIT_Trainer, UNIT_Trainer
    import torch.backends.cudnn as cudnn
    import torch

    # try:
    #     from itertools import izip as zip
    # except ImportError:  # will be 3.x series
    #     pass

    import os
    import sys
    import tensorboardX
    import shutil

    os.environ["CUDA_VISIBLE_DEVICES"] = str(0)

    parser = argparse.ArgumentParser()
    parser.add_argument('--config',
                        type=str,
                        default='configs/edges2handbags_folder.yaml',
                        help='Path to the config file.')
    parser.add_argument('--output_path',
                        type=str,
                        default='.',
                        help="outputs path")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument('--trainer',
                        type=str,
                        default='MUNIT',
                        help="MUNIT|UNIT")
    opts = parser.parse_args()

    cudnn.benchmark = True
    '''
    Note: https://www.pytorchtutorial.com/when-should-we-set-cudnn-benchmark-to-true/
        In most cases, setting this flag lets cuDNN's built-in auto-tuner search for the
        most efficient algorithm for the current configuration, improving runtime efficiency:
        1.  if the network's input dimensions and types vary little, setting
            torch.backends.cudnn.benchmark = True can increase efficiency;
        2.  if the network's input changes at every iteration, cuDNN will search for the
            optimal configuration each time, which can actually reduce efficiency.
    '''
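    # A hedged illustration (not part of the original script) of the rule in the
    # note above; fixed_size_inputs is a hypothetical flag for this sketch:
    #   fixed_size_inputs = True             # e.g. every batch is a fixed-size crop
    #   cudnn.benchmark = fixed_size_inputs  # let the auto-tuner cache one algorithm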

    # Load experiment setting
    config = get_config(opts.config)
    max_iter = config['max_iter']
    display_size = config['display_size']
    config['vgg_model_path'] = opts.output_path

    # Setup model and data loader
    if opts.trainer == 'MUNIT':
        trainer = MUNIT_Trainer(config)
    elif opts.trainer == 'UNIT':
        trainer = UNIT_Trainer(config)
    else:
        sys.exit("Only support MUNIT|UNIT")
    trainer.cuda()
    train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(
        config)
    train_display_images_a = torch.stack(
        [train_loader_a.dataset[i] for i in range(display_size)]).cuda()
    train_display_images_b = torch.stack(
        [train_loader_b.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_a = torch.stack(
        [test_loader_a.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_b = torch.stack(
        [test_loader_b.dataset[i] for i in range(display_size)]).cuda()

    # Setup logger and output folders
    model_name = os.path.splitext(os.path.basename(opts.config))[0]
    train_writer = tensorboardX.SummaryWriter(
        os.path.join(opts.output_path + "/logs", model_name))
    output_directory = os.path.join(opts.output_path + "/outputs", model_name)
    checkpoint_directory, image_directory = prepare_sub_folder(
        output_directory)
    shutil.copy(opts.config, os.path.join(
        output_directory, 'config.yaml'))  # copy config file to output folder

    # Start training
    iterations = trainer.resume(checkpoint_directory,
                                hyperparameters=config) if opts.resume else 0
    while True:
        for it, (images_a,
                 images_b) in enumerate(zip(train_loader_a, train_loader_b)):
            trainer.update_learning_rate()
            images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()

            with Timer("Elapsed time in update: %f"):
                # Main training code
                trainer.dis_update(images_a, images_b, config)
                trainer.gen_update(images_a, images_b, config)
                torch.cuda.synchronize()

            # Dump training stats in log file
            if (iterations + 1) % config['log_iter'] == 0:
                print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
                write_loss(iterations, trainer, train_writer)

            # Write images
            if (iterations + 1) % config['image_save_iter'] == 0:
                with torch.no_grad():
                    test_image_outputs = trainer.sample(
                        test_display_images_a, test_display_images_b)
                    train_image_outputs = trainer.sample(
                        train_display_images_a, train_display_images_b)
                write_2images(test_image_outputs, display_size,
                              image_directory, 'test_%08d' % (iterations + 1))
                write_2images(train_image_outputs, display_size,
                              image_directory, 'train_%08d' % (iterations + 1))
                # HTML
                write_html(output_directory + "/index.html", iterations + 1,
                           config['image_save_iter'], 'images')

            if (iterations + 1) % config['image_display_iter'] == 0:
                with torch.no_grad():
                    image_outputs = trainer.sample(train_display_images_a,
                                                   train_display_images_b)
                write_2images(image_outputs, display_size, image_directory,
                              'train_current')

            # Save network weights
            if (iterations + 1) % config['snapshot_save_iter'] == 0:
                trainer.save(checkpoint_directory, iterations)

            iterations += 1
            if iterations >= max_iter:
                sys.exit('Finish training')
Example no. 13
File: train.py Project: phonx/MUNIT
# Start training
iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0
while True:
    for it, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)):
        trainer.update_learning_rate()
        images_a, images_b = Variable(images_a.cuda()), Variable(images_b.cuda())

        # Main training code
        trainer.dis_update(images_a, images_b, config)
        trainer.gen_update(images_a, images_b, config)

        # Dump training stats in log file
        if (iterations + 1) % config['log_iter'] == 0:
            print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
            write_loss(iterations, trainer, train_writer)

        # Write images
        if (iterations + 1) % config['image_save_iter'] == 0:
            # Test set images
            image_outputs = trainer.sample(test_display_images_a, test_display_images_b)
            write_2images(image_outputs, display_size, image_directory, 'test_%08d' % (iterations + 1))
            # Train set images
            image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
            write_2images(image_outputs, display_size, image_directory, 'train_%08d' % (iterations + 1))
            # HTML
            write_html(output_directory + "/index.html", iterations + 1, config['image_save_iter'], 'images')

        if (iterations + 1) % config['image_display_iter'] == 0:
            train_display_images_a = Variable(torch.stack([train_loader_a.dataset[i] for i in range(display_size)]).cuda(), volatile=True)
            train_display_images_b = Variable(torch.stack([train_loader_b.dataset[i] for i in range(display_size)]).cuda(), volatile=True)
Example no. 14
def main(opts):
    # Load experiment setting
    config = get_config(opts.config)
    max_iter = config['max_iter']
    # Override the batch size if specified.
    if opts.batch_size != 0:
        config['batch_size'] = opts.batch_size

    trainer = Trainer(config)
    trainer.cuda()
    if opts.multigpus:
        ngpus = torch.cuda.device_count()
        config['gpus'] = ngpus
        print("Number of GPUs: %d" % ngpus)
        trainer.model = torch.nn.DataParallel(trainer.model,
                                              device_ids=range(ngpus))
    else:
        config['gpus'] = 1

    loaders = get_train_loaders(config)
    train_content_loader = loaders[0]
    train_class_loader = loaders[1]
    test_content_loader = loaders[2]
    test_class_loader = loaders[3]

    # Setup logger and output folders
    model_name = os.path.splitext(os.path.basename(opts.config))[0]
    train_writer = SummaryWriter(
        os.path.join(opts.output_path + "/logs", model_name))
    output_directory = os.path.join(opts.output_path + "/outputs", model_name)
    checkpoint_directory, image_directory = make_result_folders(
        output_directory)
    shutil.copy(opts.config, os.path.join(output_directory, 'config.yaml'))

    iterations = trainer.resume(checkpoint_directory,
                                hp=config,
                                multigpus=opts.multigpus) if opts.resume else 0

    while True:
        for it, (co_data, cl_data) in enumerate(
                zip(train_content_loader, train_class_loader)):
            with Timer("Elapsed time in update: %f"):
                d_acc = trainer.dis_update(co_data, cl_data, config)
                g_acc = trainer.gen_update(co_data, cl_data, config,
                                           opts.multigpus)
                torch.cuda.synchronize()
                print('D acc: %.4f\t G acc: %.4f' % (d_acc, g_acc))

            if (iterations + 1) % config['log_iter'] == 0:
                print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
                write_loss(iterations, trainer, train_writer)

            if ((iterations + 1) % config['image_save_iter'] == 0
                    or (iterations + 1) % config['image_display_iter'] == 0):
                if (iterations + 1) % config['image_save_iter'] == 0:
                    key_str = '%08d' % (iterations + 1)
                    write_html(output_directory + "/index.html",
                               iterations + 1, config['image_save_iter'],
                               'images')
                else:
                    key_str = 'current'
                with torch.no_grad():
                    for t, (val_co_data, val_cl_data) in enumerate(
                            zip(train_content_loader, train_class_loader)):
                        if t >= opts.test_batch_size:
                            break
                        val_image_outputs = trainer.test(
                            val_co_data, val_cl_data, opts.multigpus)
                        write_1images(val_image_outputs, image_directory,
                                      'train_%s_%02d' % (key_str, t))
                    for t, (test_co_data, test_cl_data) in enumerate(
                            zip(test_content_loader, test_class_loader)):
                        if t >= opts.test_batch_size:
                            break
                        test_image_outputs = trainer.test(
                            test_co_data, test_cl_data, opts.multigpus)
                        write_1images(test_image_outputs, image_directory,
                                      'test_%s_%02d' % (key_str, t))

            if (iterations + 1) % config['snapshot_save_iter'] == 0:
                trainer.save(checkpoint_directory, iterations, opts.multigpus)
                print('Saved model at iteration %d' % (iterations + 1))

            iterations += 1
            if iterations >= max_iter:
                print("Finish Training")
                sys.exit(0)
Example no. 15
def main(config, logger):

    print("Start extracting modalities...\n")

    modalities_encoder_trainer = ModalitiesEncoderTrainer(config, logger)
    encoder_first_epoch = modalities_encoder_trainer.load(
        config['logger']['checkpoint_dir']) if config['resume'] else 0
    modalities_encoder_trainer.train(encoder_first_epoch)

    modalities_extraction_loader = get_modalities_extraction_loader(config)
    modalities_extractor = ModalitiesExtractor(config)
    modalities = modalities_extractor.get_modalities(
        modalities_encoder_trainer.model, modalities_extraction_loader)
    modalities_grid = modalities_extractor.get_modalities_grid_image(
        modalities)
    logger.add_image("modality_per_col", modalities_grid, 0)

    del modalities_encoder_trainer
    del modalities_extractor
    torch.cuda.empty_cache()

    print(
        "Finished extracting modalities, begin training the translation network...\n"
    )

    train_source_loader, train_ref_loader, test_source_loader, test_ref_loader = get_gan_loaders(
        config, modalities)
    gan_trainer = GANTrainer(config)
    gan_trainer.to(config['device'])

    global_it = gan_trainer.resume(config['logger']['checkpoint_dir'],
                                   config) if config['resume'] else 0
    while global_it < config['gan']['max_iter']:
        for it, (source_data, ref_data) in enumerate(
                zip(train_source_loader, train_ref_loader)):
            with Timer("Elapsed time in update: %f"):
                d_acc = gan_trainer.dis_update(source_data, ref_data, config)
                g_acc = gan_trainer.gen_update(source_data, ref_data, config)

                torch.cuda.synchronize(config['device'])

                print('D acc: %.4f\t G acc: %.4f' % (d_acc, g_acc))
                print("Iteration: {curr_iter}/{total_iter}".format(
                    curr_iter=str(global_it + 1).zfill(8),
                    total_iter=str(config['gan']['max_iter']).zfill(8)))

            # Save images for evaluation
            if global_it % config['logger']['eval_every'] == 0:
                with torch.no_grad():
                    for (val_source_data,
                         val_ref_data) in zip(train_source_loader,
                                              train_ref_loader):
                        val_image_outputs = gan_trainer.test(
                            val_source_data, val_ref_data)
                        write_1images(val_image_outputs,
                                      config['logger']['image_dir'],
                                      'train_{iter}'.format(iter=global_it))
                        save_image_tb(val_image_outputs, "train", global_it,
                                      logger)
                        break
                    for (test_source_data,
                         test_ref_data) in zip(test_source_loader,
                                               test_ref_loader):
                        test_image_outputs = gan_trainer.test(
                            test_source_data, test_ref_data)
                        write_1images(test_image_outputs,
                                      config['logger']['image_dir'],
                                      'test_{iter}'.format(iter=global_it))
                        save_image_tb(test_image_outputs, "test", global_it,
                                      logger)
                        break

            # Log losses
            if global_it % config['logger']['log_loss'] == 0:
                write_loss(global_it, gan_trainer, logger)

            # Save checkpoint
            if global_it % config['logger']['checkpoint_gan_every'] == 0:
                gan_trainer.save(config['logger']['checkpoint_dir'], global_it)

            global_it += 1

    print("Finished training!")