Example #1
def train(args):
    # set up logging
    if not os.path.exists(args.log):
        os.mkdir(args.log)
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    log_path = os.path.join(args.log, 'log')
    file_handler = logging.FileHandler(log_path)
    fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    logger.info('Arguments...')
    for arg, val in vars(args).items():
        logger.info('{:>10} -----> {}'.format(arg, val))

    x, y = gen_synthetic_data(DIM, DIM_EMB, NUM)
    train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=0.2)
    valid_x, test_x, valid_y, test_y = train_test_split(test_x, test_y, test_size=0.5)

    gen = Generator(DIM_EMB)
    dis = Discriminator(DIM_EMB)

    gen_opt = optimizers.Adam()
    dis_opt = optimizers.Adam()

    gen_opt.setup(gen)
    dis_opt.setup(dis)

    trainer = GANTrainer((gen, dis), (gen_opt, dis_opt), logger, (valid_x, valid_y), args.epoch)
    trainer.fit(train_x, train_y)
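Example #1 only reads args.log and args.epoch, so a minimal command-line entry point could look like the sketch below; the flag defaults and the description string are assumptions, not part of the original script.

# Hypothetical CLI wrapper for the train() function above; only --log and
# --epoch are actually consumed by the snippet.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train a GAN on synthetic data')
    parser.add_argument('--log', default='log', help='directory for log files')
    parser.add_argument('--epoch', type=int, default=100, help='number of training epochs')
    train(parser.parse_args())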
Example #2
    def fit(self, data):
        """Fit the model to the given data.

        Args:
            data (pandas.DataFrame): dataset to fit the model to.

        Returns:
            None

        """
        self.preprocessor = Preprocessor(
            continuous_columns=self.continuous_columns)

        data = self.preprocessor.fit_transform(data)
        self.metadata = self.preprocessor.metadata
        dataflow = TGANDataFlow(data, self.metadata)
        batch_data = BatchData(dataflow, self.batch_size)
        input_queue = QueueInput(batch_data)

        self.model = self.get_model(training=True)

        if self.trainer == 'GANTrainer':
            trainer = GANTrainer(model=self.model, input_queue=input_queue)
        elif self.trainer == 'SeparateGANTrainer':
            trainer = SeparateGANTrainer(model=self.model,
                                         input_queue=input_queue)
        else:
            raise ValueError(
                'Incorrect trainer name. Use GANTrainer or SeparateGANTrainer')

        self.restore_path = os.path.join(self.model_dir, 'checkpoint')

        if os.path.isfile(self.restore_path) and self.restore_session:
            session_init = SaverRestore(self.restore_path)
            with open(os.path.join(self.log_dir, 'stats.json')) as f:
                starting_epoch = json.load(f)[-1]['epoch_num'] + 1
        else:
            session_init = None
            starting_epoch = 1

        action = 'k' if self.restore_session else None
        logger.set_logger_dir(self.log_dir, action=action)

        callbacks = []
        if self.save_checkpoints:
            callbacks.append(ModelSaver(checkpoint_dir=self.model_dir))

        trainer.train_with_defaults(callbacks=callbacks,
                                    steps_per_epoch=self.steps_per_epoch,
                                    max_epoch=self.max_epoch,
                                    session_init=session_init,
                                    starting_epoch=starting_epoch)

        self.prepare_sampling()
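A minimal usage sketch for the fit() method above, assuming it belongs to a TGAN-style model class (called TGANModel here purely for illustration) whose constructor accepts the attributes the method reads (continuous_columns, trainer, batch_size, model_dir, ...):

import pandas as pd

# Hypothetical call; the class name, file name, and column indices are assumed.
data = pd.read_csv('train-data.csv')
model = TGANModel(continuous_columns=[0, 5, 16], trainer='GANTrainer')
model.fit(data)  # preprocess the data, build the model, and run training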
Example #3
def main(config, resume):
    train_logger = Logger()

    # setup data_loader instances
    train_data_loader = get_instance(module_data, 'train_data_loader', config)
    valid_data_loader = get_instance(module_data, 'valid_data_loader', config)

    # build model architecture
    model = {}
    model["generator"] = get_instance(module_g_arch, 'generator_arch', config)
    print(model["generator"])
    model["local_discriminator"] = get_instance(module_dl_arch,
                                                'local_discriminator_arch',
                                                config)
    print(model["local_discriminator"])

    # get function handles of loss and metrics
    loss = {}
    loss["vanilla_gan"] = torch.nn.BCELoss()
    loss["lsgan"] = torch.nn.MSELoss()
    loss["ce"] = module_loss.cross_entropy2d
    loss["pg"] = module_loss.PG_Loss()
    loss["mask_ce"] = module_loss.Masked_CrossEntropy()
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer and learning-rate scheduler; delete every line containing lr_scheduler to disable the scheduler
    optimizer = {}
    generator_trainable_params = filter(lambda p: p.requires_grad,
                                        model["generator"].parameters())
    local_discriminator_trainable_params = filter(
        lambda p: p.requires_grad, model["local_discriminator"].parameters())
    optimizer["generator"] = get_instance(torch.optim, 'generator_optimizer',
                                          config, generator_trainable_params)
    optimizer["local_discriminator"] = get_instance(
        torch.optim, 'discriminator_optimizer', config,
        local_discriminator_trainable_params)
    # lr_scheduler = None # get_instance(torch.optim.lr_scheduler, 'lr_scheduler', config, optimizer)

    trainer = GANTrainer(model,
                         optimizer,
                         loss,
                         metrics,
                         resume=resume,
                         config=config,
                         data_loader=train_data_loader,
                         valid_data_loader=valid_data_loader,
                         train_logger=train_logger)
    print("pretrain models")
    trainer.pre_train()
    print("training")
    trainer.train()
    evaluator = UnetEvaluator(trainer.generator, trainer.config)
    evaluator.evaluate()
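Example #3 builds every component through a get_instance() helper; in pytorch-template style projects that helper usually resolves a 'type' and 'args' pair from the config, as in the sketch below (the helper body and the config values are assumptions, only the entry names match what the snippet looks up):

def get_instance(module, name, config, *args):
    # Resolve config[name]['type'] inside `module` and construct it with the
    # given positional args plus the keyword arguments stored under 'args'.
    return getattr(module, config[name]['type'])(*args, **config[name]['args'])

# Illustrative fragment of the expected config; values are placeholders.
config = {
    'generator_optimizer': {'type': 'Adam', 'args': {'lr': 2e-4, 'betas': [0.5, 0.999]}},
    'discriminator_optimizer': {'type': 'Adam', 'args': {'lr': 2e-4}},
    'metrics': ['psnr', 'ssim'],
}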
Example #4
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
                 (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    num_gpu = len(cfg.GPU_ID.split(','))
    if cfg.TRAIN.FLAG:
        image_transform = transforms.Compose([
            transforms.RandomCrop(cfg.IMSIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset = TextDataset(cfg.DATA_DIR,
                              'train',
                              imsize=cfg.IMSIZE,
                              transform=image_transform)
        assert dataset
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
            drop_last=True,
            shuffle=True,
            num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        datapath = '%s/test/val_captions_custom.t7' % (cfg.DATA_DIR)
        algo = GANTrainer(output_dir)
        algo.sample(datapath, cfg.STAGE)
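The cfg object in Example #4 (and in several later examples) is typically an attribute-style namespace loaded from a YAML file; a minimal sketch covering only the keys this snippet reads, with assumed values:

from easydict import EasyDict as edict

# Assumed values; real configs are loaded from a YAML file (e.g. via cfg_from_file()).
cfg = edict({
    'DATASET_NAME': 'birds', 'CONFIG_NAME': 'stageI', 'GPU_ID': '0',
    'DATA_DIR': '../data/birds', 'IMSIZE': 64, 'WORKERS': 4, 'STAGE': 1,
    'TRAIN': edict({'FLAG': True, 'BATCH_SIZE': 64}),
})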
Example #5
            if exc.errno == errno.EEXIST and os.path.isdir(path):
                pass
            else:
                raise

        copyfile(sys.argv[0], output_dir + "/" + sys.argv[0])
        copyfile("trainer.py", output_dir + "/" + "trainer.py")
        copyfile("model.py", output_dir + "/" + "model.py")
        copyfile("miscc/utils.py", output_dir + "/" + "utils.py")
        copyfile("miscc/datasets.py", output_dir + "/" + "datasets.py")
        copyfile(args.cfg_file, output_dir + "/" + "cfg_file.yml")

        imsize = 64

        img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset = TextDataset(cfg.DATA_DIR, split="train", imsize=imsize, transform=img_transform,
                              crop=True)
        assert dataset
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
            drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader)
    else:
        datapath = os.path.join(cfg.DATA_DIR, "test")
        algo = GANTrainer(output_dir)
        algo.sample(datapath, num_samples=25, draw_bbox=True)
Example #6
                vid.append(image_transform(im))
            except RuntimeError as err:
                print(err, "/", im.shape)
                raise

        vid = torch.stack(vid).permute(1, 0, 2, 3)

        return vid
    
    video_transforms = functools.partial(video_transform, image_transform=image_transforms)

    storydataset = data.StoryDataset(dir_path, video_transforms, cfg.VIDEO_LEN, True)
    imagedataset = data.ImageDataset(dir_path, image_transforms, cfg.VIDEO_LEN, True)
    testdataset = data.StoryDataset(dir_path, video_transforms, cfg.VIDEO_LEN, False)

    imageloader = torch.utils.data.DataLoader(
        imagedataset, batch_size=cfg.TRAIN.IM_BATCH_SIZE * num_gpu,
        drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

    storyloader = torch.utils.data.DataLoader(
        storydataset, batch_size=cfg.TRAIN.ST_BATCH_SIZE * num_gpu,
        drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))


    testloader = torch.utils.data.DataLoader(
        testdataset, batch_size=24 * num_gpu,
        drop_last=True, shuffle=False, num_workers=int(cfg.WORKERS))

    algo = GANTrainer(output_dir, cfg.ST_WEIGHT, test_sample_save_dir)
    algo.train(imageloader, storyloader, testloader)
Example #7
    model_name = args.model
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    num_gpu = len(cfg.GPU_ID.split(','))
    n_channels = 3

    image_transforms = transforms.Compose([
        PIL.Image.fromarray,
        transforms.Resize((cfg.IMSIZE, cfg.IMSIZE)),
        #transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        lambda x: x[:n_channels, ::],
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    video_transforms = functools.partial(video_transform,
                                         image_transform=image_transforms)
    testdataset = data.StoryDataset(dir_path, video_transforms, cfg.VIDEO_LEN,
                                    False)
    testloader = torch.utils.data.DataLoader(testdataset,
                                             batch_size=24,
                                             drop_last=True,
                                             shuffle=False,
                                             num_workers=int(cfg.WORKERS))
    output_dir = './output/%s_%s/' % \
                 (cfg.DATASET_NAME, cfg.CONFIG_NAME)
    test_sample_save_dir = output_dir + 'test/'
    trainer = GANTrainer(output_dir, cfg, cfg.ST_WEIGHT, test_sample_save_dir,
                         cfg.TENSORBOARD)
    trainer.evaluate(model_name, testloader, args.output)
Example #8
def main(args):
    np.random.seed(0)
    torch.manual_seed(0)

    with open('config.yaml', 'r') as file:
        stream = file.read()
        config_dict = yaml.safe_load(stream)
        config = mapper(**config_dict)

    disc_model = Discriminator(input_shape=(config.data.channels,
                                            config.data.hr_height,
                                            config.data.hr_width))
    gen_model = GeneratorResNet()
    feature_extractor_model = FeatureExtractor()
    plt.ion()

    if config.distributed:
        disc_model.to(device)
        disc_model = nn.parallel.DistributedDataParallel(disc_model)
        gen_model.to(device)
        gen_model = nn.parallel.DistributedDataParallel(gen_model)
        feature_extractor_model.to(device)
        feature_extractor_model = nn.parallel.DistributedDataParallel(
            feature_extractor_model)
    elif config.gpu:
        # disc_model = nn.DataParallel(disc_model).to(device)
        # gen_model = nn.DataParallel(gen_model).to(device)
        # feature_extractor_model = nn.DataParallel(feature_extractor_model).to(device)
        disc_model = disc_model.to(device)
        gen_model = gen_model.to(device)
        feature_extractor_model = feature_extractor_model.to(device)
    else:
        return

    train_dataset = ImageDataset(config.data.path,
                                 hr_shape=(config.data.hr_height,
                                           config.data.hr_width),
                                 lr_shape=(config.data.lr_height,
                                           config.data.lr_width))
    test_dataset = ImageDataset(config.data.path,
                                hr_shape=(config.data.hr_height,
                                          config.data.hr_width),
                                lr_shape=(config.data.lr_height,
                                          config.data.lr_width))

    if config.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        shuffle=config.data.shuffle,
        num_workers=config.data.workers,
        pin_memory=config.data.pin_memory,
        sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=config.data.batch_size,
                                             shuffle=config.data.shuffle,
                                             num_workers=config.data.workers,
                                             pin_memory=config.data.pin_memory)

    if args.train:
        # trainer settings
        trainer = GANTrainer(config.train, train_loader,
                             (disc_model, gen_model, feature_extractor_model))
        criterion = nn.MSELoss().to(device)
        disc_optimizer = torch.optim.Adam(disc_model.parameters(),
                                          config.train.hyperparameters.lr)
        gen_optimizer = torch.optim.Adam(gen_model.parameters(),
                                         config.train.hyperparameters.lr)
        fe_optimizer = torch.optim.Adam(feature_extractor_model.parameters(),
                                        config.train.hyperparameters.lr)

        trainer.setCriterion(criterion)
        trainer.setDiscOptimizer(disc_optimizer)
        trainer.setGenOptimizer(gen_optimizer)
        trainer.setFEOptimizer(fe_optimizer)

        # evaluator settings
        evaluator = GANEvaluator(
            config.evaluate, val_loader,
            (disc_model, gen_model, feature_extractor_model))
        # optimizer = torch.optim.Adam(disc_model.parameters(), lr=config.evaluate.hyperparameters.lr,
        # 	weight_decay=config.evaluate.hyperparameters.weight_decay)
        evaluator.setCriterion(criterion)

    if args.test:
        pass

    # Enable cuDNN benchmark mode when input sizes do not vary; it lets cuDNN
    # pick the fastest convolution algorithms for this machine.
    cudnn.benchmark = True
    start_epoch = 0
    best_precision = 0

    # optionally resume from a checkpoint
    if config.train.resume:
        [start_epoch,
         best_precision] = trainer.load_saved_checkpoint(checkpoint=None)

    # switch to config.test.hyperparameters when running in test mode
    for epoch in range(start_epoch, config.train.hyperparameters.total_epochs):
        if config.distributed:
            train_sampler.set_epoch(epoch)

        if args.train:
            trainer.adjust_learning_rate(epoch)
            trainer.train(epoch)
            prec1 = evaluator.evaluate(epoch)

        if args.test:
            pass

        # remember best prec@1 and save checkpoint
        if args.train:
            is_best = prec1 > best_precision
            best_precision = max(prec1, best_precision)
            trainer.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': disc_model.state_dict(),
                    'best_precision': best_precision,
                    'optimizer': disc_optimizer.state_dict(),  # optimizer paired with the saved discriminator weights
                },
                is_best,
                checkpoint=None)
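Example #8 turns the parsed YAML dict into attribute-style access through a mapper() helper; one possible implementation is a recursive SimpleNamespace wrapper (the name mapper comes from the snippet, the body below is an assumption):

from types import SimpleNamespace

def mapper(**kwargs):
    # Recursively wrap nested dicts so that config.data.batch_size style access works.
    return SimpleNamespace(**{
        key: mapper(**value) if isinstance(value, dict) else value
        for key, value in kwargs.items()
    })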
Example #9
    cfg.TRAIN.DISCRIMINATOR_LR = args.discriminator_lr
    cfg.TRAIN.GENERATOR_LR = args.generator_lr
    cfg.TRAIN.PRETRAINED_EPOCH = args.pretrained_epoch
    cfg.TRAIN.PRETRAINED_MODEL = args.pretrained_model
    cfg.STAGE1_G = args.stage1_g
    cfg.NET_G = args.net_g
    cfg.NET_D = args.net_d
    cfg.WORKERS = args.workers

    print('Using config:')
    pprint.pprint(cfg)

    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    if cfg.CUDA:
        torch.cuda.manual_seed_all(args.manualSeed)

    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
                 (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    num_gpu = len(cfg.GPU_ID.split(','))

    datapath = '/.local/AttnGAN/data/FashionSynthesis/test/embeddings/final.npy'
    algo = GANTrainer(output_dir)
    # algo.sample(datapath, cfg.STAGE)
    algo.sample2(datapath, cfg.STAGE)
Example #10
        elif cfg.STAGE == 2:
            resize = 268
            imsize = 256

        img_transform = transforms.Compose([
            transforms.Resize((resize, resize)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset = TextDataset(cfg.DATA_DIR,
                              cfg.IMG_DIR,
                              split="train",
                              imsize=imsize,
                              transform=img_transform,
                              crop=True,
                              stage=cfg.STAGE)
        assert dataset
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.TRAIN.BATCH_SIZE,
            drop_last=True,
            shuffle=True,
            num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        datapath = '%s/test/' % (cfg.DATA_DIR)
        algo = GANTrainer(output_dir)
        algo.sample(datapath, num_samples=25, stage=cfg.STAGE, draw_bbox=True)
Example #11
def main(args):

    # preparation
    if not os.path.exists(args.exp_dir):
        os.makedirs(args.exp_dir)
    config_logging(os.path.join(args.exp_dir, "%s.log" % args.exp_name))
    log.info("Experiment %s" % (args.exp_name))
    log.info("Receive config %s" % (args.__str__()))
    log.info("Start creating tasks")
    pretrain_task = [get_task(taskname, args) for taskname in args.pretrain_task]
    finetune_tasks = [get_task(taskname, args) for taskname in args.finetune_tasks]
    log.info("Start loading data")

    if args.image_pretrain_obj != "none" or args.view_pretrain_obj != "none":
        for task in pretrain_task:
            task.load_data()
    for task in finetune_tasks:
        task.load_data()

    log.info("Start creating models")
    if len(pretrain_task):
        if args.image_pretrain_obj != "none":
            image_ssl_model = get_model("image_ssl", args)
            log.info("Loaded image ssl model")

        if args.view_pretrain_obj != "none":
            view_ssl_model = get_model("view_ssl", args)
            log.info("Loaded view ssl model")

    if args.finetune_obj != "none": 
        sup_model = get_model("sup", args)
        log.info("Loaded supervised model")

    #if args.load_ckpt != "none":
    #    load_model(model, pretrain_complete_ckpt)

    # pretrain
    if len(pretrain_task):
        if args.image_pretrain_obj != "none":
            image_ssl_model.to(args.device)
            pretrain = Trainer("pretrain", image_ssl_model, pretrain_task[0], args)
            pretrain.train()
            image_pretrain_complete_ckpt = os.path.join(
                args.exp_dir, "image_pretrain_%s_complete.pth" % pretrain_task[0].name
            )
            save_model(image_pretrain_complete_ckpt, image_ssl_model)
        else:
            if args.imagessl_load_ckpt:
                image_pretrain_complete_ckpt = args.imagessl_load_ckpt

        if args.view_pretrain_obj != "none":
            view_ssl_model.to(args.device)
            pretrain = Trainer("pretrain", view_ssl_model, pretrain_task[0], args)
            pretrain.train()
            view_pretrain_complete_ckpt = os.path.join(
                args.exp_dir, "view_pretrain_%s_complete.pth" % pretrain_task[0].name
            )
            save_model(view_pretrain_complete_ckpt, view_ssl_model)
        else:
            if args.viewssl_load_ckpt:
                view_pretrain_complete_ckpt = args.viewssl_load_ckpt

    # finetune and test
    for task in finetune_tasks:
        if args.imagessl_load_ckpt != "none":
            # Load the pretrained image SSL weights into the supervised model,
            # renaming "patch" modules to "image" and keeping only the keys
            # that exist in the supervised model's state dict.
            pretrained_dict = torch.load(image_pretrain_complete_ckpt,
                                         map_location=torch.device('cpu'))
            model_dict = sup_model.state_dict()
            pretrained_dict = {
                k.replace("patch", "image"): v
                for k, v in pretrained_dict.items()
                if k.replace("patch", "image") in model_dict
            }
            model_dict.update(pretrained_dict)
            sup_model.load_state_dict(model_dict)

        if "adv" in args.finetune_obj:
            sup_model["generator"].to(args.device)
            sup_model["discriminator"].to(args.device)
            finetune = GANTrainer("finetune", sup_model, task, args)
        else:
            sup_model.to(args.device)
            finetune = Trainer("finetune", sup_model, task, args)

        finetune.train()
        finetune.eval("test")
        if "adv" in args.finetune_obj:
            finetune_generator_complete_ckpt = os.path.join(
                    args.exp_dir, "finetune_%s_generator_complete.pth" % task.name
                )

            save_model(finetune_generator_complete_ckpt, sup_model["generator"])

            finetune_discriminator_complete_ckpt = os.path.join(
                    args.exp_dir, "finetune_%s_discriminator_complete.pth" % task.name
                )

            save_model(finetune_discriminator_complete_ckpt, sup_model["discriminator"])
        
        else:
            finetune_complete_ckpt = os.path.join(
                    args.exp_dir, "finetune_%s_complete.pth" % task.name
                )

            save_model(finetune_complete_ckpt, sup_model)
        

    # evaluate
    # TODO: evaluate result on test split, write prediction for leaderboard submission (for dataset
    # without test labels)
    log.info("Done")
    return
Example #12
    random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    if cfg.CUDA:
        torch.cuda.manual_seed_all(args.manualSeed)
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
                 (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    num_gpu = len(cfg.GPU_ID.split(','))
    if cfg.TRAIN.FLAG:
        image_transform = transforms.Compose([
            transforms.RandomCrop(cfg.IMSIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset = TextDataset(cfg.DATA_DIR, 'train',
                              imsize=cfg.IMSIZE,
                              transform=image_transform)
        assert dataset
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
            drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        datapath= '%s/test/val_captions.t7' % (cfg.DATA_DIR)
        algo = GANTrainer(output_dir)
        algo.sample(datapath, cfg.STAGE)
Example #13
def main(gpu_id, data_dir, manual_seed, cuda, train_flag, image_size,
         batch_size, workers, stage, dataset_name, config_name, max_epoch,
         snapshot_interval, net_g, net_d, z_dim, generator_lr,
         discriminator_lr, lr_decay_epoch, coef_kl, stage1_g, embedding_type,
         condition_dim, df_dim, gf_dim, res_num, text_dim, regularizer):

    if manual_seed is None:
        manual_seed = random.randint(1, 10000)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)
    if cuda:
        torch.cuda.manual_seed_all(manual_seed)
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % (dataset_name, config_name, timestamp)

    num_gpu = len(gpu_id.split(','))
    if train_flag:
        image_transform = transforms.Compose([
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset = TextDataset(data_dir,
                              'train',
                              imsize=image_size,
                              transform=image_transform)
        assert dataset
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch_size *
                                                 num_gpu,
                                                 drop_last=True,
                                                 shuffle=True,
                                                 num_workers=int(workers))

        algo = GANTrainer(output_dir, max_epoch, snapshot_interval, gpu_id,
                          batch_size, train_flag, net_g, net_d, cuda, stage1_g,
                          z_dim, generator_lr, discriminator_lr,
                          lr_decay_epoch, coef_kl, regularizer)
        algo.train(dataloader, stage, text_dim, gf_dim, condition_dim, z_dim,
                   df_dim, res_num)

    elif dataset_name == 'birds' and train_flag is False:
        image_transform = transforms.Compose([
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset = TextDataset(data_dir,
                              'train',
                              imsize=image_size,
                              transform=image_transform)
        assert dataset
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch_size *
                                                 num_gpu,
                                                 drop_last=True,
                                                 shuffle=True,
                                                 num_workers=int(workers))

        algo = GANTrainer(output_dir, max_epoch, snapshot_interval, gpu_id,
                          batch_size, train_flag, net_g, net_d, cuda, stage1_g,
                          z_dim, generator_lr, discriminator_lr,
                          lr_decay_epoch, coef_kl, regularizer)
        algo.birds_eval(dataloader, stage)

    else:
        datapath = '%s/test/val_captions.t7' % (data_dir)
        algo = GANTrainer(output_dir, max_epoch, snapshot_interval, gpu_id,
                          batch_size, train_flag, net_g, net_d, cuda, stage1_g,
                          z_dim, generator_lr, discriminator_lr,
                          lr_decay_epoch, coef_kl, regularizer)
        algo.sample(datapath, stage)
Example #14
        mkdir_p(os.path.join(output_dir, 'model_reserve'))
    else:
        split = 'train'
        batch_size = cfg.TRAIN.BATCH_SIZE * num_gpu
        shuffle_flag = True

    Dataset = choose_dataset(cfg.DATASET_NAME)
    dataset = Dataset(cfg.DATA_DIR, split, imsize=cfg.IMSIZE)

    # Note the batchsize setting is here
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size,
        drop_last=True, shuffle=shuffle_flag, num_workers=int(cfg.WORKERS))

    # Initialize the main class which includes the training and evaluation
    algo = GANTrainer(output_dir, cfg_path=args.cfg_file)

    if cfg.TRAIN.FLAG:
        algo.train(dataloader)

    elif args.eval:
        date_str = datetime.datetime.now().strftime('%b-%d-%I%M%p-%G')

        if args.FID_eval:
            '''Do FID evaluations.'''
            f = open(os.path.join(output_dir, 'all_FID_eval.txt'), 'a')
            for net_G_name in net_G_names:
                cfg.NET_G = net_G_name
                algo.sample(dataloader, eval_name='eval',
                        eval_num=args.eval_num)
                fid_score_now = \
Example #15
        dataset = TextImageDataset(data_dir=cfg.DATA_DIR,
                                   ann_file=cfg.ANN_FILE,
                                   imsize=cfg.IMSIZE,
                                   emb_model=cfg.EMB_MODEL,
                                   transform=image_transform,
                                   vocab_file=vocab)

        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
            collate_fn=collate_fn,
            drop_last=True,
            shuffle=True,
            num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir,
                          cap_model,
                          vocab,
                          eval_utils,
                          my_resnet,
                          dataset.word2idx,
                          dataset.emb,
                          dataset.idx2word,
                          vocab_cap=vocab_cap,
                          eval_kwargs=vars(opt))
        algo.train(dataloader, cfg.STAGE)
    else:
        datapath = '%s/test/val_captions.t7' % (cfg.DATA_DIR)
        algo = GANTrainer(output_dir)
        algo.sample(datapath, cfg.STAGE)
Example #16
                                        video_transforms,
                                        is_train=False)

        testloader = torch.utils.data.DataLoader(
            testdataset,
            batch_size=cfg.TRAIN.ST_BATCH_SIZE * num_gpu,
            drop_last=True,
            shuffle=False,
            num_workers=int(cfg.WORKERS))

        if args.eval_fid:
            algo = Infer(output_dir, 1.0)
            algo.eval_fid2(testloader, video_transforms, image_transforms)

        elif args.eval_fvd:
            algo = Infer(output_dir, 1.0)
            algo.eval_fvd(imageloader, storyloader, testloader, cfg.STAGE)

        elif args.load_ckpt is not None:
            # For inference training result
            algo = Infer(output_dir, 1.0, args.load_ckpt)
            algo.inference(imageloader, storyloader, testloader, cfg.STAGE)
        else:
            # For training model
            algo = GANTrainer(output_dir, args, ratio=1.0)
            algo.train(imageloader, storyloader, testloader, cfg.STAGE)
    else:
        datapath = '%s/test/val_captions.t7' % (cfg.DATA_DIR)
        algo = GANTrainer(output_dir)
        algo.sample(datapath, cfg.STAGE)
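Example #16 branches on three command-line options; a hedged sketch of the matching argparse definitions (the option names mirror the attributes read above, the defaults and help strings are assumptions):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--eval_fid', action='store_true', help='compute FID on the test set')
parser.add_argument('--eval_fvd', action='store_true', help='compute FVD on the test set')
parser.add_argument('--load_ckpt', default=None, help='checkpoint path for inference')
args = parser.parse_args()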
Example #17
        torch.cuda.manual_seed_all(args.manualSeed)

    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
                 (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    num_gpu = len(cfg.GPU_ID.split(','))
    if cfg.TRAIN.FLAG:
        image_transform = transforms.Compose([
            transforms.RandomCrop(cfg.IMSIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset = TextDataset(cfg.DATA_DIR,
                              'train',
                              embedding_type='lstm',
                              imsize=cfg.IMSIZE,
                              transform=image_transform)
        assert dataset
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
            drop_last=True,
            shuffle=True,
            num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)