def main(config, resume):
    train_logger = Logger()

    # setup data_loader instances
    train_data_loader = get_instance(module_data, 'train_data_loader', config)
    valid_data_loader = get_instance(module_data, 'valid_data_loader', config)

    # build model architecture
    model = {}
    model["generator"] = get_instance(module_g_arch, 'generator_arch', config)
    print(model["generator"])
    model["local_discriminator"] = get_instance(module_dl_arch,
                                                'local_discriminator_arch',
                                                config)
    print(model["local_discriminator"])

    # get function handles of loss and metrics
    loss = {}
    loss["vanilla_gan"] = torch.nn.BCELoss()
    loss["lsgan"] = torch.nn.MSELoss()
    loss["ce"] = module_loss.cross_entropy2d
    loss["pg"] = module_loss.PG_Loss()
    loss["mask_ce"] = module_loss.Masked_CrossEntropy()
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer and learning rate scheduler; delete all lines containing lr_scheduler to disable the scheduler
    optimizer = {}
    generator_trainable_params = filter(lambda p: p.requires_grad,
                                        model["generator"].parameters())
    local_discriminator_trainable_params = filter(
        lambda p: p.requires_grad, model["local_discriminator"].parameters())
    optimizer["generator"] = get_instance(torch.optim, 'generator_optimizer',
                                          config, generator_trainable_params)
    optimizer["local_discriminator"] = get_instance(
        torch.optim, 'discriminator_optimizer', config,
        local_discriminator_trainable_params)
    # lr_scheduler = None # get_instance(torch.optim.lr_scheduler, 'lr_scheduler', config, optimizer)

    trainer = GANTrainer(model,
                         optimizer,
                         loss,
                         metrics,
                         resume=resume,
                         config=config,
                         data_loader=train_data_loader,
                         valid_data_loader=valid_data_loader,
                         train_logger=train_logger)
    print("pretrain models")
    trainer.pre_train()
    print("training")
    trainer.train()
    evaluator = UnetEvaluator(trainer.generator, trainer.config)
    evaluator.evaluate()
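Example #1 builds every component through a `get_instance` helper that is not shown above. In pytorch-template style projects this helper typically looks up a class by name in the config and instantiates it with the config's `args`; the following is a minimal sketch of that assumed pattern, not the project's actual definition.

def get_instance(module, name, config, *args):
    # Assumed helper: config[name] holds a 'type' (class name in `module`)
    # and an 'args' dict of keyword arguments for its constructor.
    return getattr(module, config[name]['type'])(*args, **config[name]['args'])
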
Example #2
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
                 (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    num_gpu = len(cfg.GPU_ID.split(','))
    if cfg.TRAIN.FLAG:
        image_transform = transforms.Compose([
            transforms.RandomCrop(cfg.IMSIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset = TextDataset(cfg.DATA_DIR,
                              'train',
                              imsize=cfg.IMSIZE,
                              transform=image_transform)
        assert dataset
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
            drop_last=True,
            shuffle=True,
            num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        datapath = '%s/test/val_captions_custom.t7' % (cfg.DATA_DIR)
        algo = GANTrainer(output_dir)
        algo.sample(datapath, cfg.STAGE)
Example #3
            if exc.errno == errno.EEXIST and os.path.isdir(path):
                pass
            else:
                raise

        copyfile(sys.argv[0], output_dir + "/" + sys.argv[0])
        copyfile("trainer.py", output_dir + "/" + "trainer.py")
        copyfile("model.py", output_dir + "/" + "model.py")
        copyfile("miscc/utils.py", output_dir + "/" + "utils.py")
        copyfile("miscc/datasets.py", output_dir + "/" + "datasets.py")
        copyfile(args.cfg_file, output_dir + "/" + "cfg_file.yml")

        imsize = 64

        img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset = TextDataset(cfg.DATA_DIR, split="train", imsize=imsize, transform=img_transform,
                              crop=True)
        assert dataset
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
            drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader)
    else:
        datapath = os.path.join(cfg.DATA_DIR, "test")
        algo = GANTrainer(output_dir)
        algo.sample(datapath, num_samples=25, draw_bbox=True)
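The truncated `except` handler at the top of Example #3 is the classic pre-Python-3.2 idiom for creating a directory that may already exist. A minimal modern equivalent (assuming `output_dir` is the directory being created) avoids the `errno` check entirely:

import os

# Create output_dir (including parents) and ignore the case where it
# already exists, replacing the try/except errno.EEXIST pattern above.
os.makedirs(output_dir, exist_ok=True)
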
Example #4
                vid.append(image_transform(im))
            except RuntimeError as err:
                print(err, "/", im.shape)
                raise

        vid = torch.stack(vid).permute(1, 0, 2, 3)

        return vid
    
    video_transforms = functools.partial(video_transform, image_transform=image_transforms)

    storydataset = data.StoryDataset(dir_path, video_transforms, cfg.VIDEO_LEN, True)
    imagedataset = data.ImageDataset(dir_path, image_transforms, cfg.VIDEO_LEN, True)
    testdataset = data.StoryDataset(dir_path, video_transforms, cfg.VIDEO_LEN, False)

    imageloader = torch.utils.data.DataLoader(
        imagedataset, batch_size=cfg.TRAIN.IM_BATCH_SIZE * num_gpu,
        drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

    storyloader = torch.utils.data.DataLoader(
        storydataset, batch_size=cfg.TRAIN.ST_BATCH_SIZE * num_gpu,
        drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))


    testloader = torch.utils.data.DataLoader(
        testdataset, batch_size=24 * num_gpu,
        drop_last=True, shuffle=False, num_workers=int(cfg.WORKERS))

    algo = GANTrainer(output_dir, cfg.ST_WEIGHT, test_sample_save_dir)
    algo.train(imageloader, storyloader, testloader)
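Only the tail of `video_transform` survives at the top of Example #4: it applies `image_transform` to each frame, stacks the frames, and permutes the tensor to (C, T, H, W). A minimal sketch of what the full helper presumably looks like (the exact signature is an assumption based on the `functools.partial` call):

def video_transform(video, image_transform):
    # Assumed reconstruction: transform each frame, stack to (T, C, H, W),
    # then permute to (C, T, H, W) as in the surviving lines above.
    vid = []
    for im in video:
        vid.append(image_transform(im))
    return torch.stack(vid).permute(1, 0, 2, 3)
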
Example #5
def main(args):
    np.random.seed(0)
    torch.manual_seed(0)

    with open('config.yaml', 'r') as file:
        stream = file.read()
        config_dict = yaml.safe_load(stream)
        config = mapper(**config_dict)

    disc_model = Discriminator(input_shape=(config.data.channels,
                                            config.data.hr_height,
                                            config.data.hr_width))
    gen_model = GeneratorResNet()
    feature_extractor_model = FeatureExtractor()
    plt.ion()

    if config.distributed:
        disc_model.to(device)
        disc_model = nn.parallel.DistributedDataParallel(disc_model)
        gen_model.to(device)
        gen_model = nn.parallel.DistributedDataParallel(gen_model)
        feature_extractor_model.to(device)
        feature_extractor_model = nn.parallel.DistributedDataParallel(
            feature_extractor_model)
    elif config.gpu:
        # disc_model = nn.DataParallel(disc_model).to(device)
        # gen_model = nn.DataParallel(gen_model).to(device)
        # feature_extractor_model = nn.DataParallel(feature_extractor_model).to(device)
        disc_model = disc_model.to(device)
        gen_model = gen_model.to(device)
        feature_extractor_model = feature_extractor_model.to(device)
    else:
        return

    train_dataset = ImageDataset(config.data.path,
                                 hr_shape=(config.data.hr_height,
                                           config.data.hr_width),
                                 lr_shape=(config.data.lr_height,
                                           config.data.lr_width))
    test_dataset = ImageDataset(config.data.path,
                                hr_shape=(config.data.hr_height,
                                          config.data.hr_width),
                                lr_shape=(config.data.lr_height,
                                          config.data.lr_width))

    if config.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        shuffle=config.data.shuffle,
        num_workers=config.data.workers,
        pin_memory=config.data.pin_memory,
        sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=config.data.batch_size,
                                             shuffle=config.data.shuffle,
                                             num_workers=config.data.workers,
                                             pin_memory=config.data.pin_memory)

    if args.train:
        # trainer settings
        trainer = GANTrainer(config.train, train_loader,
                             (disc_model, gen_model, feature_extractor_model))
        criterion = nn.MSELoss().to(device)
        disc_optimizer = torch.optim.Adam(disc_model.parameters(),
                                          config.train.hyperparameters.lr)
        gen_optimizer = torch.optim.Adam(gen_model.parameters(),
                                         config.train.hyperparameters.lr)
        fe_optimizer = torch.optim.Adam(feature_extractor_model.parameters(),
                                        config.train.hyperparameters.lr)

        trainer.setCriterion(criterion)
        trainer.setDiscOptimizer(disc_optimizer)
        trainer.setGenOptimizer(gen_optimizer)
        trainer.setFEOptimizer(fe_optimizer)

        # evaluator settings
        evaluator = GANEvaluator(
            config.evaluate, val_loader,
            (disc_model, gen_model, feature_extractor_model))
        # optimizer = torch.optim.Adam(disc_model.parameters(), lr=config.evaluate.hyperparameters.lr,
        # 	weight_decay=config.evaluate.hyperparameters.weight_decay)
        evaluator.setCriterion(criterion)

    if args.test:
        pass

    # Enable cudnn benchmark mode when the input sizes do not vary;
    # it lets cuDNN search for the fastest algorithms on this machine.
    cudnn.benchmark = True
    start_epoch = 0
    best_precision = 0

    # optionally resume from a checkpoint
    if config.train.resume:
        [start_epoch,
         best_precision] = trainer.load_saved_checkpoint(checkpoint=None)

    # switch to config.test.hyperparameters when running in test mode
    for epoch in range(start_epoch, config.train.hyperparameters.total_epochs):
        if config.distributed:
            train_sampler.set_epoch(epoch)

        if args.train:
            trainer.adjust_learning_rate(epoch)
            trainer.train(epoch)
            prec1 = evaluator.evaluate(epoch)

        if args.test:
            pass

        # remember best prec@1 and save checkpoint
        if args.train:
            is_best = prec1 > best_precision
            best_precision = max(prec1, best_precision)
            trainer.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': disc_model.state_dict(),
                    'best_precision': best_precision,
                    'optimizer': disc_optimizer.state_dict(),
                },
                is_best,
                checkpoint=None)
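Example #5 converts the parsed YAML dict into attribute-style access with `mapper(**config_dict)` so that expressions like `config.data.batch_size` work. The helper itself is not shown; below is a minimal sketch of the assumed behaviour using a recursive `SimpleNamespace` wrapper (the name `mapper` comes from the call site, the implementation is an assumption).

from types import SimpleNamespace

def mapper(**kwargs):
    # Assumed helper: recursively wrap nested dicts so that
    # config.train.hyperparameters.lr style attribute access works.
    def wrap(value):
        if isinstance(value, dict):
            return SimpleNamespace(**{k: wrap(v) for k, v in value.items()})
        return value
    return wrap(kwargs)
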
Example #6
def main(args):

    # preparation
    if not os.path.exists(args.exp_dir):
        os.makedirs(args.exp_dir)
    config_logging(os.path.join(args.exp_dir, "%s.log" % args.exp_name))
    log.info("Experiment %s" % (args.exp_name))
    log.info("Receive config %s" % (args.__str__()))
    log.info("Start creating tasks")
    pretrain_task = [get_task(taskname, args) for taskname in args.pretrain_task]
    finetune_tasks = [get_task(taskname, args) for taskname in args.finetune_tasks]
    log.info("Start loading data")

    if args.image_pretrain_obj != "none" or args.view_pretrain_obj != "none":
        for task in pretrain_task:
            task.load_data()
    for task in finetune_tasks:
        task.load_data()

    log.info("Start creating models")
    if len(pretrain_task):
        if args.image_pretrain_obj != "none":
            image_ssl_model = get_model("image_ssl", args)
            log.info("Loaded image ssl model")

        if args.view_pretrain_obj != "none":
            view_ssl_model = get_model("view_ssl", args)
            log.info("Loaded view ssl model")

    if args.finetune_obj != "none": 
        sup_model = get_model("sup", args)
        log.info("Loaded supervised model")

    #if args.load_ckpt != "none":
    #    load_model(model, pretrain_complete_ckpt)

    # pretrain
    if len(pretrain_task):
        if args.image_pretrain_obj != "none":
            image_ssl_model.to(args.device)
            pretrain = Trainer("pretrain", image_ssl_model, pretrain_task[0], args)
            pretrain.train()
            image_pretrain_complete_ckpt = os.path.join(
                args.exp_dir, "image_pretrain_%s_complete.pth" % pretrain_task[0].name
            )
            save_model(image_pretrain_complete_ckpt, image_ssl_model)
        else:
            if args.imagessl_load_ckpt:
                image_pretrain_complete_ckpt = args.imagessl_load_ckpt

        if args.view_pretrain_obj != "none":
            view_ssl_model.to(args.device)
            pretrain = Trainer("pretrain", view_ssl_model, pretrain_task[0], args)
            pretrain.train()
            view_pretrain_complete_ckpt = os.path.join(
                args.exp_dir, "view_pretrain_%s_complete.pth" % pretrain_task[0].name
            )
            save_model(view_pretrain_complete_ckpt, view_ssl_model)
        else:
            if args.viewssl_load_ckpt:
                view_pretrain_complete_ckpt = args.viewssl_load_ckpt

    # finetune and test
    for task in finetune_tasks:
        if args.imagessl_load_ckpt != "none":
            pretrained_dict = torch.load(image_pretrain_complete_ckpt,
                                         map_location=torch.device('cpu'))
            model_dict = sup_model.state_dict()

            # Remap the pretrained "patch" keys to the supervised model's
            # "image" keys and keep only those present in the target model.
            pretrained_dict = {k.replace("patch", "image"): v
                               for k, v in pretrained_dict.items()
                               if k.replace("patch", "image") in model_dict}

            model_dict.update(pretrained_dict)
            sup_model.load_state_dict(model_dict)

        if "adv" in args.finetune_obj:
            # print(type(sup_model))
            sup_model["generator"].to(args.device)
            sup_model["discriminator"].to(args.device)
            finetune = GANTrainer("finetune", sup_model, task, args)
        else:
            sup_model.to(args.device)
            finetune = Trainer("finetune", sup_model, task, args)

        finetune.train()
        finetune.eval("test")
        if "adv" in args.finetune_obj:
            finetune_generator_complete_ckpt = os.path.join(
                    args.exp_dir, "finetune_%s_generator_complete.pth" % task.name
                )

            save_model(finetune_generator_complete_ckpt, sup_model["generator"])

            finetune_discriminator_complete_ckpt = os.path.join(
                    args.exp_dir, "finetune_%s_discriminator_complete.pth" % task.name
                )

            save_model(finetune_discriminator_complete_ckpt, sup_model["discriminator"])
        
        else:
            finetune_complete_ckpt = os.path.join(
                    args.exp_dir, "finetune_%s_complete.pth" % task.name
                )

            save_model(finetune_complete_ckpt, sup_model)
        

    # evaluate
    # TODO: evaluate result on test split, write prediction for leaderboard submission (for dataset
    # without test labels)
    log.info("Done")
    return
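Example #6 calls `save_model` and `load_model` without showing them. Assuming they are thin wrappers around `torch.save`/`torch.load` of a raw `state_dict` (which is what the partial-loading code above expects to find on disk), a minimal sketch would be:

def save_model(path, model):
    # Assumed wrapper: persist only the state_dict, matching the
    # torch.load(...) usage in the finetune branch above.
    torch.save(model.state_dict(), path)

def load_model(model, path):
    # Assumed counterpart used by the commented-out load_ckpt branch.
    model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
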
Example #7
    random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    if cfg.CUDA:
        torch.cuda.manual_seed_all(args.manualSeed)
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
                 (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    num_gpu = len(cfg.GPU_ID.split(','))
    if cfg.TRAIN.FLAG:
        image_transform = transforms.Compose([
            transforms.RandomCrop(cfg.IMSIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset = TextDataset(cfg.DATA_DIR, 'train',
                              imsize=cfg.IMSIZE,
                              transform=image_transform)
        assert dataset
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
            drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))

        algo = GANTrainer(output_dir)
        algo.train(dataloader, cfg.STAGE)
    else:
        datapath = '%s/test/val_captions.t7' % (cfg.DATA_DIR)
        algo = GANTrainer(output_dir)
        algo.sample(datapath, cfg.STAGE)
from utils import mkdir_p

manualSeed = random.randint(1, 10000)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
if CUDA:
    torch.cuda.manual_seed_all(manualSeed)
now = datetime.datetime.now(dateutil.tz.tzlocal())
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
output_dir = '/content/output/%s_%s' % \
              ('102flowers', timestamp)

num_gpu = len(GPU_ID.split(','))
if TRAIN_FLAG:
    image_transform = transforms.Compose([
        transforms.RandomCrop(IMSIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset = TextDataset(DATA_DIR, 'jpg', imsize=IMSIZE, transform=image_transform)
    assert dataset
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=TRAIN_BATCH_SIZE * num_gpu,
        drop_last=True, shuffle=True, num_workers=1)

from trainer import GANTrainer
algo = GANTrainer(output_dir)
algo.train(dataloader, STAGE)

def main(gpu_id, data_dir, manual_seed, cuda, train_flag, image_size,
         batch_size, workers, stage, dataset_name, config_name, max_epoch,
         snapshot_interval, net_g, net_d, z_dim, generator_lr,
         discriminator_lr, lr_decay_epoch, coef_kl, stage1_g, embedding_type,
         condition_dim, df_dim, gf_dim, res_num, text_dim, regularizer):

    if manual_seed is None:
        manual_seed = random.randint(1, 10000)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)
    if cuda:
        torch.cuda.manual_seed_all(manual_seed)
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % (dataset_name, config_name, timestamp)

    num_gpu = len(gpu_id.split(','))
    if train_flag:
        image_transform = transforms.Compose([
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset = TextDataset(data_dir,
                              'train',
                              imsize=image_size,
                              transform=image_transform)
        assert dataset
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch_size *
                                                 num_gpu,
                                                 drop_last=True,
                                                 shuffle=True,
                                                 num_workers=int(workers))

        algo = GANTrainer(output_dir, max_epoch, snapshot_interval, gpu_id,
                          batch_size, train_flag, net_g, net_d, cuda, stage1_g,
                          z_dim, generator_lr, discriminator_lr,
                          lr_decay_epoch, coef_kl, regularizer)
        algo.train(dataloader, stage, text_dim, gf_dim, condition_dim, z_dim,
                   df_dim, res_num)

    elif dataset_name == 'birds' and train_flag is False:
        image_transform = transforms.Compose([
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset = TextDataset(data_dir,
                              'train',
                              imsize=image_size,
                              transform=image_transform)
        assert dataset
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch_size *
                                                 num_gpu,
                                                 drop_last=True,
                                                 shuffle=True,
                                                 num_workers=int(workers))

        algo = GANTrainer(output_dir, max_epoch, snapshot_interval, gpu_id,
                          batch_size, train_flag, net_g, net_d, cuda, stage1_g,
                          z_dim, generator_lr, discriminator_lr,
                          lr_decay_epoch, coef_kl, regularizer)
        algo.birds_eval(dataloader, stage)

    else:
        datapath = '%s/test/val_captions.t7' % (data_dir)
        algo = GANTrainer(output_dir, max_epoch, snapshot_interval, gpu_id,
                          batch_size, train_flag, net_g, net_d, cuda, stage1_g,
                          z_dim, generator_lr, discriminator_lr,
                          lr_decay_epoch, coef_kl, regularizer)
        algo.sample(datapath, stage)