Example #1
import random

import torch
import torch.nn as nn
import torch.distributed as distributed

# load_trained_mp, get_trainable_params, set_random_seed, AnnealingStepLR and
# ModelTrainer are assumed to be imported from the surrounding codebase.


def train(process_id, CFG):
    if 'GQN' in CFG.arch:
        from models.baseline_gqn import GQN as ScnModel
        print(" --- Arch: GQN ---")
    elif 'IODINE' in CFG.arch:
        from models.baseline_iodine import IODINE as ScnModel
        print(" --- Arch: IODINE ---")
    elif 'MulMON' in CFG.arch:
        from models.mulmon import MulMON as ScnModel
        print(" --- Arch: MulMON ---")
    else:
        raise NotImplementedError

    rank = CFG.nrank * CFG.gpus + process_id
    gpu = process_id + CFG.gpu_start  # e.g. gpus=2, gpu_start=1 means using gpu [0+1, 1+1] = [1, 2].

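    # Join the NCCL process group; with init_method='env://' the rendezvous address and
    # port are read from the MASTER_ADDR / MASTER_PORT environment variables.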
    distributed.init_process_group(backend='nccl',
                                   init_method='env://',
                                   world_size=CFG.world_size,
                                   rank=rank)

    if CFG.seed is None:
        CFG.seed = random.randint(0, 1000000)
    set_random_seed(CFG.seed)

    # Create the model
    scn_model = ScnModel(CFG)
    torch.cuda.set_device(gpu)

    if CFG.resume_epoch is not None:
        state_dict = load_trained_mp(CFG.resume_path)
        scn_model.load_state_dict(state_dict, strict=True)

    scn_model.cuda(gpu)
    params_to_update = get_trainable_params(scn_model)

    if CFG.optimiser == 'RMSprop':
        optimiser = torch.optim.RMSprop(params_to_update,
                                        lr=CFG.lr_rate,
                                        weight_decay=CFG.weight_decay)
        lr_scheduler = None
    else:
        optimiser = torch.optim.Adam(params_to_update,
                                     lr=CFG.lr_rate,
                                     weight_decay=CFG.weight_decay)
        lr_scheduler = AnnealingStepLR(optimiser,
                                       mu_i=CFG.lr_rate,
                                       mu_f=0.1 * CFG.lr_rate,
                                       n=1.0e6)

    scn_model = nn.parallel.DistributedDataParallel(scn_model,
                                                    device_ids=[gpu])

    if 'gqn' in CFG.DATA_TYPE:
        from data_loader.getGqnH5 import distributed_loader
    elif 'clevr' in CFG.DATA_TYPE:
        from data_loader.getClevrMV import distributed_loader
    else:
        raise NotImplementedError

    train_dataset = distributed_loader(CFG.DATA_ROOT,
                                       CFG.train_data_filename,
                                       num_slots=CFG.num_slots,
                                       use_bg=CFG.use_bg)
    val_dataset = distributed_loader(CFG.DATA_ROOT,
                                     CFG.test_data_filename,
                                     num_slots=CFG.num_slots,
                                     use_bg=CFG.use_bg)

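    # DistributedSampler gives each rank a disjoint shard of the dataset and handles
    # shuffling, which is why shuffle=False is passed to the DataLoaders below.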
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=CFG.world_size, rank=rank)
    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_dataset, num_replicas=CFG.world_size, rank=rank)

    # Get the data loaders
    train_dl = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=CFG.batch_size,
                                           shuffle=False,
                                           num_workers=0,
                                           pin_memory=True,
                                           collate_fn=lambda x: tuple(zip(*x)),
                                           sampler=train_sampler)
    val_dl = torch.utils.data.DataLoader(dataset=val_dataset,
                                         batch_size=CFG.batch_size,
                                         shuffle=False,
                                         num_workers=0,
                                         pin_memory=True,
                                         collate_fn=lambda x: tuple(zip(*x)),
                                         sampler=val_sampler)
    trainer = ModelTrainer(model=scn_model,
                           loss=None,
                           metrics=None,
                           optimizer=optimiser,
                           step_per_epoch=CFG.step_per_epoch,
                           config=CFG,
                           train_data_loader=train_dl,
                           valid_data_loader=val_dl,
                           device=gpu,
                           lr_scheduler=lr_scheduler)
    # Start training session
    trainer.train()
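
A minimal launch sketch for the function above, assuming CFG carries the gpus / nrank / world_size fields it already reads and that a single node is used. The MASTER_ADDR / MASTER_PORT variables are required by init_method='env://'; the wrapper itself is illustrative, not part of the original code:

import os
import torch.multiprocessing as mp

def launch(CFG):
    # Hypothetical single-node wrapper: one worker process per local GPU.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    CFG.nrank = 0                  # node rank (single node assumed)
    CFG.world_size = CFG.gpus      # total number of processes
    mp.spawn(train, nprocs=CFG.gpus, args=(CFG,))
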
Example #2
        stride_to_obs = 2
        nf_to_obs = 128
        nf_dec = 128
        nf_z = 3
        nf_v = 5
    # else:
        # raise NotImplementedError
        
    # Define model
    model = JUMP(nt, stride_to_hidden, nf_to_hidden, nf_enc, stride_to_obs, nf_to_obs, nf_dec, nf_z, nf_v).to(device)

    if len(args.device_ids) > 1:
        model = nn.DataParallel(model, device_ids=args.device_ids)
        
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-08)
    scheduler = AnnealingStepLR(optimizer, mu_i=5e-4, mu_f=5e-5, n=1.6e6)

    kwargs = {'num_workers':num_workers, 'pin_memory': True} if torch.cuda.is_available() else {}

    train_loader = DataLoader(train_dataset, batch_size=B, shuffle=True, **kwargs)
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, **kwargs)

    f_data_test, v_data_test = next(iter(test_loader))
    N = f_data_test.size(1)

    step = 0
    # Training Iterations
    for epoch in range(args.num_epoch):
        for t, (f_data, v_data) in enumerate(tqdm(train_loader)):
            model.train()
            
Example #3
    B = args.batch_size
    
    # Number of generative layers
    L = args.layers

    # Maximum number of training steps
    S_max = args.gradient_steps

    # Define model
    model = GQN(representation=args.representation, L=L, shared_core=args.shared_core).to(device)
    if args.continue_training == 'False':
        if len(args.device_ids) > 1:
            model = nn.DataParallel(model, device_ids=args.device_ids)

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-08)
        scheduler = AnnealingStepLR(optimizer, mu_i=5e-5, mu_f=5e-6, n=1.6e6)
        t_start = 0
    else:
        path = args.path_to_model

        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-08)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler = AnnealingStepLR(optimizer, mu_i=5e-4, mu_f=5e-5, n=1.6e6)
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        sigma = checkpoint['sigma']
        sigma_i = sigma
        t_start = checkpoint['time_step']
        model.train()
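
AnnealingStepLR appears in every example but is not defined in them. A minimal sketch consistent with its (optimizer, mu_i, mu_f, n) call signature, under the assumption that it anneals the learning rate linearly from mu_i down to mu_f over n scheduler steps:

from torch.optim.lr_scheduler import _LRScheduler

class AnnealingStepLR(_LRScheduler):
    """Linearly anneal the learning rate from mu_i to mu_f over n steps (assumed behaviour)."""

    def __init__(self, optimizer, mu_i, mu_f, n, last_epoch=-1):
        self.mu_i = mu_i
        self.mu_f = mu_f
        self.n = n
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Clamp at mu_f once n steps have elapsed.
        mu = max(self.mu_f + (self.mu_i - self.mu_f) * (1.0 - self.last_epoch / self.n),
                 self.mu_f)
        return [mu for _ in self.base_lrs]

Because this subclasses _LRScheduler, it inherits state_dict() and load_state_dict(), which is what the checkpoint handling in Example #3 relies on.
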
Example #4
def train(gpu_id, CFG):
    if 'GQN' in CFG.arch:
        from models.baseline_gqn import GQN as ScnModel
        print(" --- Arch: GQN ---")
    elif 'IODINE' in CFG.arch:
        from models.baseline_iodine import IODINE as ScnModel
        print(" --- Arch: IODINE ---")
    elif 'MulMON' in CFG.arch:
        from models.mulmon import MulMON as ScnModel
        print(" --- Arch: MulMON ---")
    else:
        raise NotImplementedError

    # Create the model
    scn_model = ScnModel(CFG)
    if CFG.resume_epoch is not None:
        state_dict = load_trained_mp(CFG.resume_path)
        scn_model.load_state_dict(state_dict, strict=True)
    params_to_update = get_trainable_params(scn_model)

    if CFG.optimiser == 'RMSprop':
        optimiser = torch.optim.RMSprop(params_to_update,
                                        lr=CFG.lr_rate,
                                        weight_decay=CFG.weight_decay)
        lr_scheduler = None
    else:
        optimiser = torch.optim.Adam(params_to_update,
                                     lr=CFG.lr_rate,
                                     weight_decay=CFG.weight_decay)
        lr_scheduler = AnnealingStepLR(optimiser,
                                       mu_i=CFG.lr_rate,
                                       mu_f=0.1 * CFG.lr_rate,
                                       n=1e6)

    if 'gqn' in CFG.DATA_TYPE:
        from data_loader.getGqnH5 import DataLoader
    elif 'clevr' in CFG.DATA_TYPE:
        from data_loader.getClevrMV import DataLoader
    else:
        raise NotImplementedError

    # Get the data loaders
    train_dl = DataLoader(CFG.DATA_ROOT,
                          CFG.train_data_filename,
                          batch_size=CFG.batch_size,
                          shuffle=True,
                          num_slots=CFG.num_slots,
                          use_bg=True)
    val_dl = DataLoader(CFG.DATA_ROOT,
                        CFG.test_data_filename,
                        batch_size=CFG.batch_size,
                        shuffle=True,
                        num_slots=CFG.num_slots,
                        use_bg=True)

    if CFG.seed is None:
        CFG.seed = random.randint(0, 1000000)
    set_random_seed(CFG.seed)

    trainer = ModelTrainer(model=scn_model,
                           loss=None,
                           metrics=None,
                           optimizer=optimiser,
                           step_per_epoch=CFG.step_per_epoch,
                           config=CFG,
                           train_data_loader=train_dl,
                           valid_data_loader=val_dl,
                           device=gpu_id,
                           lr_scheduler=lr_scheduler)
    # Start training session
    trainer.train()
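
Examples #1 and #4 also call set_random_seed and get_trainable_params without showing them. Minimal sketches, assuming they do the obvious thing (seed the Python/NumPy/PyTorch RNGs, and collect only parameters that require gradients):

import random

import numpy as np
import torch

def set_random_seed(seed):
    # Seed every RNG the training code may touch (assumed behaviour).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def get_trainable_params(model):
    # Hand the optimiser only the parameters that still require gradients.
    return [p for p in model.parameters() if p.requires_grad]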