Example #1

import torch
import torch.multiprocessing as mp

# Project-local helpers assumed importable in this repository:
# utils, dist_init, Trainer.
def main():
    parser = utils.prepare_parser()
    parser = utils.add_dgp_parser(parser)

    config = vars(parser.parse_args())
    utils.dgp_update_config(config)

    # Print parameters
    cat = [
        'model', 'dgp_mode', 'list_file', 'exp_path', 'root_dir', 'resolution',
        'random_G', 'update_G', 'custom_mask'
    ]
    for key, val in config.items():
        if key in cat:
            print(key, ":", str(val))

    if config['custom_mask']:
        config['mask_path'] = '../data/input/' + config['mask_path']
        print('mask_path :', config['mask_path'])

    rank = 0
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    if config['dist']:
        rank, world_size = dist_init(config['port'])

    # Seed RNG
    utils.seed_rng(rank + config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # train
    trainer = Trainer(config)
    trainer.run()
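
A minimal way to run this entry point, assuming the snippet lives in a standalone script (the guard is standard Python; the example flags are assumptions based on the config keys printed above):

if __name__ == '__main__':
    # e.g. python main.py --dist --port 23456 --seed 0
    main()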
Example #2

import os

import yaml
import torch.multiprocessing as mp

# Project-local helpers assumed importable in this repository:
# dist_init, Trainer.
def main(args):
    with open(args.config) as f:
        config = yaml.safe_load(f)  # yaml.load without an explicit Loader is deprecated and unsafe

    for k, v in config.items():
        setattr(args, k, v)

    # exp path
    if not hasattr(args, 'exp_path'):
        args.exp_path = os.path.dirname(args.config)

    # dist init
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    dist_init(args.launcher, backend='nccl')

    # train
    trainer = Trainer(args)
    trainer.run()
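
The loop above promotes every top-level YAML key to an attribute on the argparse namespace, overwriting any CLI value of the same name. A self-contained sketch of that merge pattern (keys and values are made up for illustration):

import argparse
import yaml

args = argparse.Namespace(config='cfg.yaml', lr=0.1)
config = yaml.safe_load('lr: 0.01\nbatch_size: 32')  # stands in for open(args.config)
for k, v in config.items():
    setattr(args, k, v)
print(args.lr, args.batch_size)  # -> 0.01 32: the YAML value wins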
Example #3

import torch
import torch.multiprocessing as mp

# Project-local helpers assumed importable in this repository:
# utils, dist_init, Trainer.
def main():
    parser = utils.prepare_parser()
    parser = utils.add_dgp_parser(parser)
    config = vars(parser.parse_args())
    utils.dgp_update_config(config)
    print(config)

    rank = 0
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    if config['dist']:
        rank, world_size = dist_init(config['port'])

    # Seed RNG
    utils.seed_rng(rank + config['seed'])

    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    # train
    trainer = Trainer(config)
    trainer.run()
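
Offsetting the seed by rank gives each worker a distinct but reproducible RNG stream, so augmentation and sampling differ across processes while the run stays repeatable. A hypothetical stand-in for utils.seed_rng (the real helper may differ):

import random

import numpy as np
import torch

def seed_rng(seed):
    # Seed every RNG a training job typically touches
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)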
Example #4

import os

import torch
import torch.distributed as dist
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader

# Project-local helpers assumed importable in this repository:
# parse_config_or_kwargs, dist_init, getoneNode, check_dir, get_logger_2,
# store_yaml, set_seed, ImageFolder, train, and the model classes named
# by conf['model_type'].
def main(config, rank, world_size, gpu_id, port, kwargs):
    torch.backends.cudnn.benchmark = True

    conf = parse_config_or_kwargs(config, **kwargs)

    # --------- multi machine train set up --------------
    if conf['train_local'] == 1:
        host_addr = 'localhost'
        conf['rank'] = rank
        conf['local_rank'] = gpu_id  # specify the local gpu id
        conf['world_size'] = world_size
        dist_init(host_addr, conf['rank'], conf['local_rank'],
                  conf['world_size'], port)
    else:
        host_addr = getoneNode()
        conf['rank'] = int(os.environ['SLURM_PROCID'])
        conf['local_rank'] = int(os.environ['SLURM_LOCALID'])
        conf['world_size'] = int(os.environ['SLURM_NTASKS'])
        dist_init(host_addr, conf['rank'], conf['local_rank'],
                  conf['world_size'], '2' + os.environ['SLURM_JOBID'][-4:])
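        # The port is derived from the SLURM job id ('2' + its last four
        # digits), presumably so concurrent jobs rendezvous on distinct ports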
        gpu_id = conf['local_rank']
    # --------- multi machine train set up --------------

    # setup logger
    if conf['rank'] == 0:
        check_dir(conf['exp_dir'])
        logger = get_logger_2(os.path.join(conf['exp_dir'], 'train.log'),
                              "[ %(asctime)s ] %(message)s")
    dist.barrier()  # let rank 0 create the log directory first
    if conf['rank'] != 0:
        logger = get_logger_2(os.path.join(conf['exp_dir'], 'train.log'),
                              "[ %(asctime)s ] %(message)s")

    logger.info("Rank: {}/{}, local rank:{} is running".format(
        conf['rank'], conf['world_size'], conf['rank']))

    # write the config file to the exp_dir
    if conf['rank'] == 0:
        store_path = os.path.join(conf['exp_dir'], 'config.yaml')
        store_yaml(config, store_path, **kwargs)

    cuda_id = 'cuda:' + str(gpu_id)
    conf['device'] = torch.device(
        cuda_id if torch.cuda.is_available() else 'cpu')

    model_dir = os.path.join(conf['exp_dir'], 'models')
    if conf['rank'] == 0:
        check_dir(model_dir)
    conf['checkpoint_format'] = os.path.join(model_dir, '{}.th')

    set_seed(666 + conf['rank'])

    # Instantiate the model class named in the config (eval() assumes a
    # trusted config file); 'R'-type models take an extra argument t
    if 'R' in conf['model_type']:
        model = eval(conf['model_type'])(base_ch_num=conf['base_ch_num'],
                                         t=conf['t'])
    else:
        model = eval(conf['model_type'])(base_ch_num=conf['base_ch_num'])
    model = model.to(conf['device'])
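    # Wrap the model in DistributedDataParallel so gradients are averaged
    # across ranks during the backward pass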
    model = DDP(model,
                device_ids=[conf['local_rank']],
                output_device=conf['local_rank'])
    optimizer = optim.Adam(model.parameters(),
                           lr=conf['lr'],
                           betas=(0.5, 0.99))

    if conf['rank'] == 0:
        num_params = sum(param.numel() for param in model.parameters())
        logger.info("Model type: {} Base channel num:{}".format(
            conf['model_type'], conf['base_ch_num']))
        logger.info("Number of parameters: {:.4f}M".format(1.0 * num_params /
                                                           1e6))
        logger.info(optimizer)

    train_set = ImageFolder(root=conf['root'],
                            mode='train',
                            augmentation_prob=conf['aug_prob'],
                            crop_size_min=conf['crop_size_min'],
                            crop_size_max=conf['crop_size_max'],
                            data_num=conf['data_num'],
                            gauss_size=conf['gauss_size'],
                            data_aug_list=conf['aug_list'])
    train_loader = DataLoader(dataset=train_set,
                              batch_size=conf['batch_size'],
                              shuffle=conf['shuffle'],
                              num_workers=conf['num_workers'])
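    # Note: no DistributedSampler is used, so every rank iterates the full
    # training set; with DDP one would typically shard the data per rank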

    # The dev set reuses the training split with augmentation disabled
    dev_set = ImageFolder(root=conf['root'],
                          mode='train',
                          augmentation_prob=0.0)
    dev_loader = DataLoader(dataset=dev_set,
                            batch_size=5,
                            shuffle=False,
                            num_workers=1)

    valid_set = ImageFolder(root=conf['root'], mode='valid')
    valid_loader = DataLoader(dataset=valid_set,
                              batch_size=5,
                              shuffle=False,
                              num_workers=1)

    test_set = ImageFolder(root=conf['root'], mode='test')
    test_loader = DataLoader(dataset=test_set,
                             batch_size=5,
                             shuffle=False,
                             num_workers=1)

    dist.barrier()  # make sure every rank has finished setup before training starts
    train(model, train_loader, test_loader, dev_loader, optimizer, conf,
          logger)
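
For the train_local == 1 branch, a launcher along these lines would start one worker per GPU with torch.multiprocessing.spawn (the wrapper, config file name, and port are illustrative, not from the original):

import torch
import torch.multiprocessing as mp

def _worker(gpu_id, config, world_size, port, kwargs):
    # On a single node the rank can simply equal the local GPU id
    main(config, gpu_id, world_size, gpu_id, port, kwargs)

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(_worker,
             args=('config.yaml', world_size, '23456', {}),
             nprocs=world_size)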