Example 1
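# Supervised training entry point for the coach model: parse arguments, build
# the model and datasets, then run the epoch loop with early stopping.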
def main():
    torch.backends.cudnn.benchmark = True

    parser = common_utils.Parser()
    parser.add_parser('main', get_main_parser())
    parser.add_parser('coach', ConvRnnCoach.get_arg_parser())

    args = parser.parse()
    parser.log()

    options = args['main']

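    # Create the output folder; outside dev mode, tee stdout into train.log.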
    if not os.path.exists(options.model_folder):
        os.makedirs(options.model_folder)
    logger_path = os.path.join(options.model_folder, 'train.log')
    if not options.dev:
        sys.stdout = common_utils.Logger(logger_path)

    if options.dev:
        options.train_dataset = options.train_dataset.replace('train.', 'dev.')
        options.val_dataset = options.val_dataset.replace('val.', 'dev.')

    print('Args:\n%s\n' % pprint.pformat(vars(options)))

    if options.gpu < 0:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:%d' % options.gpu)

    common_utils.set_all_seeds(options.seed)

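    # Instantiate the coach variant selected by options.coach_type.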
    model_args = args['coach']
    if options.coach_type == 'onehot':
        model = ConvOneHotCoach(model_args, 0, options.max_instruction_span,
                                options.num_resource_bin).to(device)
    elif options.coach_type in ['rnn', 'bow']:
        model = ConvRnnCoach(model_args, 0, options.max_instruction_span,
                             options.coach_type,
                             options.num_resource_bin).to(device)
    elif options.coach_type == 'rnn_gen':
        model = RnnGenerator(model_args, 0, options.max_instruction_span,
                             options.num_resource_bin).to(device)
    else:
        assert False, 'unknown coach_type: %s' % options.coach_type

    print(model)

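    # Train/val/eval datasets all reuse the model's instruction dictionary; the
    # eval split is additionally limited to model.args.num_pos_inst instructions.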
    train_dataset = CoachDataset(
        options.train_dataset,
        options.moving_avg_decay,
        options.num_resource_bin,
        options.resource_bin_size,
        options.max_num_prev_cmds,
        model.inst_dict,
        options.max_instruction_span,
    )
    val_dataset = CoachDataset(
        options.val_dataset,
        options.moving_avg_decay,
        options.num_resource_bin,
        options.resource_bin_size,
        options.max_num_prev_cmds,
        model.inst_dict,
        options.max_instruction_span,
    )
    eval_dataset = CoachDataset(options.val_dataset,
                                options.moving_avg_decay,
                                options.num_resource_bin,
                                options.resource_bin_size,
                                options.max_num_prev_cmds,
                                model.inst_dict,
                                options.max_instruction_span,
                                num_instructions=model.args.num_pos_inst)

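    # Precompute dataset caches up front (skipped in dev mode).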
    if not options.dev:
        compute_cache(train_dataset)
        compute_cache(val_dataset)
        compute_cache(eval_dataset)

    if options.optim == 'adamax':
        optimizer = torch.optim.Adamax(model.parameters(),
                                       lr=options.lr,
                                       betas=(options.beta1, options.beta2))
    elif options.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=options.lr,
                                     betas=(options.beta1, options.beta2))
    else:
        assert False, 'not supported'

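    # Data loaders; memory is pinned only when training on a GPU.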
    train_loader = DataLoader(
        train_dataset,
        options.batch_size,
        shuffle=True,
        num_workers=1,  # 0 if options.dev else 10
        pin_memory=(options.gpu >= 0))
    val_loader = DataLoader(
        val_dataset,
        options.batch_size,
        shuffle=False,
        num_workers=1,  # 0 if options.dev else 10
        pin_memory=(options.gpu >= 0))
    eval_loader = DataLoader(
        eval_dataset,
        options.batch_size,
        shuffle=False,
        num_workers=1,  # 0 if options.dev else 10
        pin_memory=(options.gpu >= 0))

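    # Epoch loop with early stopping: break after two consecutive epochs in
    # which validation NLL fails to improve.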
    best_val_nll = float('inf')
    overfit_count = 0
    for epoch in range(1, options.epochs + 1):
        print('==========')
        train(model, device, optimizer, options.grad_clip, train_loader, epoch)
        with torch.no_grad(), common_utils.EvalMode(model):
            val_nll = evaluate(model, device, val_loader, epoch, 'val', False)
            eval_nll = evaluate(model, device, eval_loader, epoch, 'eval',
                                True)

        model_file = os.path.join(options.model_folder,
                                  'checkpoint%d.pt' % epoch)
        print('saving model to', model_file)
        model.save(model_file)

        if val_nll < best_val_nll:
            print('!!!New Best Model')
            overfit_count = 0
            best_val_nll = val_nll
            best_model_file = os.path.join(options.model_folder,
                                           'best_checkpoint.pt')
            print('saving best model to', best_model_file)
            model.save(best_model_file)
        else:
            overfit_count += 1
            if overfit_count == 2:
                break

    print('train DONE')
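
Note: the evaluation block above pairs torch.no_grad() with common_utils.EvalMode.
A minimal sketch of such a context manager, assuming it only toggles the module
between eval and train modes (the project's actual implementation may differ):

import torch.nn as nn

class EvalMode:
    """Put a module into eval mode on entry; restore train mode on exit."""

    def __init__(self, module: nn.Module):
        self.module = module

    def __enter__(self):
        # remember the previous mode so exit can restore it
        self.was_training = self.module.training
        self.module.eval()
        return self.module

    def __exit__(self, exc_type, exc_value, traceback):
        if self.was_training:
            self.module.train()
        return False  # do not suppress exceptions
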
Example 2
    parser.add_argument("--actor_sync_freq", type=int, default=10)

    args = parser.parse_args()
    assert args.method in ["vdn", "iql"]
    return args


if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    args = parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    logger_path = os.path.join(args.save_dir, "train.log")
    sys.stdout = common_utils.Logger(logger_path)
    saver = common_utils.TopkSaver(args.save_dir, 5)

    common_utils.set_all_seeds(args.seed)
    pprint.pprint(vars(args))

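    # VDN batches pack all num_player agents together, so rescale the batch
    # and buffer sizes accordingly.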
    if args.method == "vdn":
        args.batchsize = int(np.round(args.batchsize / args.num_player))
        args.replay_buffer_size //= args.num_player
        args.burn_in_frames //= args.num_player

    explore_eps = utils.generate_explore_eps(args.act_base_eps,
                                             args.act_eps_alpha, args.num_eps)
    expected_eps = np.mean(explore_eps)
    print("explore eps:", explore_eps)
    print("avg explore eps:", expected_eps)
Example 3
    # others
    parser.add_argument("--num_eval_game", type=int, default=10)
    parser.add_argument("--record_time", type=int, default=0)

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    args = parse_args()

    common_utils.set_all_seeds(args.seed)
    pprint.pprint(vars(args))

    # save_dir must exist before the log files can be opened there
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    sys.stdout = common_utils.Logger(os.path.join(args.save_dir, "train.log"))
    sys.stderr = common_utils.Logger(os.path.join(args.save_dir, "train.err"))

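    # Resolve the game's action-space size, then build the agent for the
    # selected algorithm.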
    num_action = create_atari.get_num_action(args.game)
    if args.algo == "r2d2":
        net_cons = lambda device: AtariLSTMNet(device, num_action)
        agent = R2D2Agent(
            net_cons,
            args.train_device,
            args.multi_step,
            args.gamma,
            args.eta,
            args.seq_len,
            args.seq_burn_in,
            args.same_hid,
        )
Example 4
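# Behavior-cloning training entry point for the executor model.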
def main():
    torch.backends.cudnn.benchmark = True

    parser = common_utils.Parser()
    parser.add_parser('main', get_main_parser())
    parser.add_parser('executor', Executor.get_arg_parser())
    args = parser.parse()
    parser.log()

    options = args['main']

    if not os.path.exists(options.model_folder):
        os.makedirs(options.model_folder)
    logger_path = os.path.join(options.model_folder, 'train.log')
    if not options.dev:
        sys.stdout = common_utils.Logger(logger_path)

    if options.dev:
        options.train_dataset = options.train_dataset.replace('train.', 'dev.')
        options.val_dataset = options.val_dataset.replace('val.', 'dev.')

    print('Args:\n%s\n' % pprint.pformat(vars(options)))

    if options.gpu < 0:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:%d' % options.gpu)

    common_utils.set_all_seeds(options.seed)

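    # Build the executor; its instruction dictionary is shared with the datasets.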
    model = Executor(args['executor'], options.num_resource_bin).to(device)
    inst_dict = model.inst_dict
    print(model)
    # model = nn.DataParallel(model, [0, 1])

    train_dataset = BehaviorCloneDataset(
        options.train_dataset,
        options.num_resource_bin,
        options.resource_bin_size,
        options.max_num_prev_cmds,
        inst_dict=inst_dict,
        word_based=is_word_based(args['executor'].inst_encoder_type))
    val_dataset = BehaviorCloneDataset(options.val_dataset,
                                       options.num_resource_bin,
                                       options.resource_bin_size,
                                       options.max_num_prev_cmds,
                                       inst_dict=inst_dict,
                                       word_based=is_word_based(
                                           args['executor'].inst_encoder_type))

    if options.optim == 'adamax':
        optimizer = torch.optim.Adamax(model.parameters(),
                                       lr=options.lr,
                                       betas=(options.beta1, options.beta2))
    elif options.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=options.lr,
                                     betas=(options.beta1, options.beta2))
    else:
        assert False, 'not supported'

    train_loader = DataLoader(
        train_dataset,
        options.batch_size,
        shuffle=True,
        num_workers=20,  # 0 if options.dev else 20
        pin_memory=(options.gpu >= 0))
    val_loader = DataLoader(
        val_dataset,
        options.batch_size,
        shuffle=False,
        num_workers=20,  # 0 if options.dev else 20
        pin_memory=(options.gpu >= 0))

    best_eval_nll = float('inf')
    overfit_count = 0

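    # Counters that aggregate and report per-epoch train/eval statistics.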
    train_stat = common_utils.MultiCounter(
        os.path.join(options.model_folder, 'train'))
    eval_stat = common_utils.MultiCounter(
        os.path.join(options.model_folder, 'eval'))
    for epoch in range(1, options.epochs + 1):
        train_stat.start_timer()
        train(model, device, optimizer, options.grad_clip, train_loader, epoch,
              train_stat)
        train_stat.summary(epoch)
        train_stat.reset()

        with torch.no_grad(), common_utils.EvalMode(model):
            eval_stat.start_timer()
            eval_nll = evaluate(model, device, val_loader, epoch, eval_stat)
            eval_stat.summary(epoch)
            eval_stat.reset()

        model_file = os.path.join(options.model_folder,
                                  'checkpoint%d.pt' % epoch)
        print('saving model to', model_file)
        if isinstance(model, nn.DataParallel):
            model.module.save(model_file)
        else:
            model.save(model_file)

        if eval_nll < best_eval_nll:
            print('!!!New Best Model')
            overfit_count = 0
            best_eval_nll = eval_nll
            best_model_file = os.path.join(options.model_folder,
                                           'best_checkpoint.pt')
            print('saving best model to', best_model_file)
            if isinstance(model, nn.DataParallel):
                model.module.save(best_model_file)
            else:
                model.save(best_model_file)
        else:
            overfit_count += 1
            if overfit_count == 2:
                break

    print('train DONE')
Example 5
    parser.add_argument("--seed", default=10001, type=int, help="Random seed")

    parser.add_argument("--num_gpu", default=1, type=int)
    parser.add_argument("--replay_buffer_size", default=2**21, type=int)
    parser.add_argument("--burn_in_frames", default=1000, type=int)

    args = parser.parse_args()
    args.num_action = get_num_action(args.game)
    return args


if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True

    args = parse_args()
    # save_dir must exist before the log files can be opened there
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    sys.stdout = common_utils.Logger(
        os.path.join(args.save_dir, "benchmark.log"))
    sys.stderr = common_utils.Logger(
        os.path.join(args.save_dir, "benchmark.err"))

    # Thread/game configurations to benchmark; the commented pairs below are
    # larger grids kept for reference.
    thread_game_worker = [
        (80, 20),
        (80, 40),
        (80, 80),
        (80, 160),
        # (120, 20), (120, 40), (120, 80), (120, 160),
        # (160, 20), (160, 40), (160, 80), (160, 160),
    ]
    benchmark(thread_game_worker, args)