Ejemplo n.º 1
0
def main():
    """Train an Executor by behavior cloning, with per-epoch checkpoints.

    Parses the 'main' and 'executor' argument groups, builds the train and
    validation ``BehaviorCloneDataset``s, then alternates one training epoch
    (``train``) with one validation pass (``evaluate``).

    After every epoch the model is saved to ``checkpoint<epoch>.pt`` in
    ``options.model_folder``. Whenever the validation NLL improves, it is
    also saved to ``best_checkpoint.pt``. Training stops early once the
    validation NLL has failed to improve for ``max_overfit_epochs``
    consecutive epochs.

    Raises:
        ValueError: if ``options.optim`` is neither 'adamax' nor 'adam'.
    """
    torch.backends.cudnn.benchmark = True

    parser = common_utils.Parser()
    parser.add_parser('main', get_main_parser())
    parser.add_parser('executor', Executor.get_arg_parser())
    args = parser.parse()
    parser.log()

    options = args['main']
    executor_args = args['executor']

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(options.model_folder, exist_ok=True)
    logger_path = os.path.join(options.model_folder, 'train.log')
    if options.dev:
        # Dev runs keep console output and switch to the small dev splits.
        options.train_dataset = options.train_dataset.replace('train.', 'dev.')
        options.val_dataset = options.val_dataset.replace('val.', 'dev.')
    else:
        # Mirror stdout into <model_folder>/train.log for real runs.
        sys.stdout = common_utils.Logger(logger_path)

    print('Args:\n%s\n' % pprint.pformat(vars(options)))

    if options.gpu < 0:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:%d' % options.gpu)
    pin_memory = options.gpu >= 0

    common_utils.set_all_seeds(options.seed)

    model = Executor(executor_args, options.num_resource_bin).to(device)
    inst_dict = model.inst_dict
    print(model)

    # Both datasets share the model's instruction vocabulary and encoding.
    word_based = is_word_based(executor_args.inst_encoder_type)
    train_dataset = BehaviorCloneDataset(
        options.train_dataset,
        options.num_resource_bin,
        options.resource_bin_size,
        options.max_num_prev_cmds,
        inst_dict=inst_dict,
        word_based=word_based)
    val_dataset = BehaviorCloneDataset(
        options.val_dataset,
        options.num_resource_bin,
        options.resource_bin_size,
        options.max_num_prev_cmds,
        inst_dict=inst_dict,
        word_based=word_based)

    # A real exception (not `assert`) so the check survives `python -O`.
    if options.optim == 'adamax':
        optimizer_cls = torch.optim.Adamax
    elif options.optim == 'adam':
        optimizer_cls = torch.optim.Adam
    else:
        raise ValueError('not supported: optim=%r' % (options.optim,))
    optimizer = optimizer_cls(model.parameters(),
                              lr=options.lr,
                              betas=(options.beta1, options.beta2))

    train_loader = DataLoader(
        train_dataset,
        options.batch_size,
        shuffle=True,
        num_workers=20,
        pin_memory=pin_memory)
    val_loader = DataLoader(
        val_dataset,
        options.batch_size,
        shuffle=False,
        num_workers=20,
        pin_memory=pin_memory)

    # Save the underlying module when wrapped in DataParallel so checkpoints
    # do not depend on the wrapper.
    model_to_save = model.module if isinstance(model, nn.DataParallel) else model

    best_eval_nll = float('inf')
    overfit_count = 0
    max_overfit_epochs = 2  # early-stopping patience
    best_model_file = os.path.join(options.model_folder, 'best_checkpoint.pt')

    train_stat = common_utils.MultiCounter(
        os.path.join(options.model_folder, 'train'))
    eval_stat = common_utils.MultiCounter(
        os.path.join(options.model_folder, 'eval'))
    for epoch in range(1, options.epochs + 1):
        train_stat.start_timer()
        train(model, device, optimizer, options.grad_clip, train_loader, epoch,
              train_stat)
        train_stat.summary(epoch)
        train_stat.reset()

        with torch.no_grad(), common_utils.EvalMode(model):
            eval_stat.start_timer()
            eval_nll = evaluate(model, device, val_loader, epoch, eval_stat)
            eval_stat.summary(epoch)
            eval_stat.reset()

        model_file = os.path.join(options.model_folder,
                                  'checkpoint%d.pt' % epoch)
        print('saving model to', model_file)
        model_to_save.save(model_file)

        if eval_nll < best_eval_nll:
            print('!!!New Best Model')
            overfit_count = 0
            best_eval_nll = eval_nll
            print('saving best model to', best_model_file)
            model_to_save.save(best_model_file)
        else:
            overfit_count += 1
            if overfit_count >= max_overfit_epochs:
                break

    print('train DONE')
Ejemplo n.º 2
0
        games,
    )
    act_group.start()
    context.start()
    while replay_buffer.size() < args.burn_in_frames:
        print("warming up replay buffer:", replay_buffer.size())
        time.sleep(1)

    print("Success, Done")
    print("=======================")

    frame_stat = dict()
    frame_stat["num_acts"] = 0
    frame_stat["num_buffer"] = 0

    stat = common_utils.MultiCounter(args.save_dir)
    tachometer = utils.Tachometer()
    stopwatch = common_utils.Stopwatch()

    for epoch in range(args.num_epoch):
        print("beginning of epoch: ", epoch)
        print(common_utils.get_mem_usage())
        tachometer.start()
        stat.reset()
        stopwatch.reset()

        for batch_idx in range(args.epoch_len):
            num_update = batch_idx + epoch * args.epoch_len
            if num_update % args.num_update_between_sync == 0:
                agent.sync_target_with_online()
            if num_update % args.actor_sync_freq == 0:
Ejemplo n.º 3
0
def _contrastive_loss(pos_loss, neg_loss, loss_type, temp):
    """Combine positive/negative distances into the scalar training loss.

    ``"nce"``: ``-exp(-pos/T) / (exp(-pos/T) + exp(-neg/T))``, evaluated as the
    algebraically identical ``-sigmoid((neg - pos) / T)``, which does not
    produce 0/0 = NaN when both exponentials underflow.
    ``"subtract"``: ``pos - neg``.

    Raises:
        NotImplementedError: for any other ``loss_type``.
    """
    if loss_type == "nce":
        return -torch.sigmoid((neg_loss - pos_loss) / temp)
    if loss_type == "subtract":
        return pos_loss - neg_loss
    raise NotImplementedError(f"{loss_type} is unknown")


def _connection_cov_op(conns, n_samples):
    """Return the label-weighted covariance operator of accumulated connections.

    ``conns`` stacks, along dim 0, the connection sums for label 0 and label 1
    (each a 2-D tensor); ``n_samples[j]`` is how many samples of label ``j``
    were summed into ``conns[j]``. Each label's mean connection is compared
    against the overall mean, weighted by that label's sample fraction. Labels
    with no samples contribute nothing and are skipped, so an absent label
    cannot turn the result into NaN through a 0/0 division.
    """
    n_total = sum(n_samples)
    avg_conn = conns.sum(dim=0) / n_total
    cov_op = torch.zeros(avg_conn.size(0), avg_conn.size(0),
                         dtype=avg_conn.dtype, device=avg_conn.device)
    for j, n_j in enumerate(n_samples):
        if n_j == 0:
            continue
        diff = conns[j] / n_j - avg_conn
        cov_op += diff @ diff.t() * n_j / n_total
    return cov_op


def main(args):
    """Train a contrastive model on synthetic latent-tree data and log diagnostics.

    Builds a ``Latent`` tree rooted at "s", a ``SimpleDataset`` of ``args.N``
    samples, and a ``Model`` on CUDA. Each epoch first runs a no-grad pass
    that, per latent node, tracks the correlation between ground-truth
    latents and hidden activations and accumulates input-Jacobian
    "connections" grouped by binary label; a JSON summary of these is logged.
    It then runs one SGD pass minimising a contrastive loss built from the
    MSE between the embedding of ``x`` and those of ``x_pos`` / ``x_neg``.
    Finally the whole model object is written with ``torch.save`` to
    ``args.save_file``.

    Raises:
        NotImplementedError: if ``args.loss`` is not "nce" or "subtract".
    """
    # torch.backends.cudnn.benchmark = True
    log.info(f"Working dir: {os.getcwd()}")
    log.info("\n" + common_utils.get_git_hash())
    log.info("\n" + common_utils.get_git_diffs())

    common_utils.set_all_seeds(args.seed)
    log.info(args.pretty())

    # Setup latent variables.
    root = Latent("s", args)
    iterator = LatentIterator(root)
    dataset = SimpleDataset(iterator, args.N)

    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batchsize,
                                         shuffle=True,
                                         num_workers=4)

    # Hard-coded run switches.
    load_saved = False
    need_train = True

    if load_saved:
        # NOTE(review): torch.load unpickles the file; only use trusted paths.
        model = torch.load(args.save_file)
    else:
        model = Model(iterator, args.hid, args.eps)
    model.cuda()

    if need_train:
        # Fail fast on a bad loss name instead of after a full evaluation pass.
        if args.loss not in ("nce", "subtract"):
            raise NotImplementedError(f"{args.loss} is unknown")

        loss_func = nn.MSELoss().cuda()
        optimizer = optim.SGD(model.parameters(), lr=args.lr)

        stats = common_utils.MultiCounter("./")
        stats_corrs = {
            parent.name: common_utils.StatsCorr()
            for parent in iterator.top_down()
        }

        for i in range(args.num_epoch):
            for stats_corr in stats_corrs.values():
                stats_corr.reset()

            connections = dict()
            n_samples = [0, 0]  # per-label sample counts (labels 0 and 1)

            with torch.no_grad():
                for batch in tqdm.tqdm(loader, total=len(loader)):
                    d = batch2dict(batch)
                    label = d["x_label"]
                    _, d_hs, d_inputs = model(d["x"])
                    Js = model.computeJ(d_hs)

                    # Correlation between ground-truth latents and activations.
                    d_gt = dataset.split_generated(d["x_all"])
                    for v in iterator.top_down():
                        name = v.name
                        stats_corrs[name].add(d_gt[name].unsqueeze(1).float(),
                                              d_hs[name])

                        inputs = d_inputs[name].detach()
                        conn = torch.einsum("ia,ibc->iabc", inputs, Js[name])
                        # Flatten (a, b) into one axis: (batch, a*b, c).
                        conn = conn.view(conn.size(0),
                                         conn.size(1) * conn.size(2),
                                         conn.size(3))
                        # Sum connections separately for each label.
                        conns = torch.stack([conn[label == 0].sum(dim=0),
                                             conn[label == 1].sum(dim=0)])

                        if name in connections:
                            connections[name] += conns
                        else:
                            connections[name] = conns

                    for j in range(2):
                        n_samples[j] += (label == j).sum().item()

            json_result = dict(epoch=i)
            for name, conns in connections.items():
                cov_op = _connection_cov_op(conns, n_samples)
                json_result["conn_" + name] = dict(size=cov_op.size(0),
                                                   norm=cov_op.norm().item())
                json_result["weight_norm_" +
                            name] = model.nets[name].weight.norm().item()

            # [sum of best correlations, node count] per tree depth.
            layer_avgs = [[0, 0] for _ in range(args.depth + 1)]
            for p in iterator.top_down():
                corr = stats_corrs[p.name].get()["corr"]
                # Take the absolute value: -1 and +1 are equally good.
                res = common_utils.corr_summary(corr.abs())
                best = res["best_corr"].item()
                json_result["best_corr_" + p.name] = best

                layer_avgs[p.depth][0] += best
                layer_avgs[p.depth][1] += 1

            # Average best correlation for each layer.
            for depth, (sum_corr, n) in enumerate(layer_avgs):
                if n > 0:
                    log.info(
                        f"[{depth}] Mean of the best corr: {sum_corr/n:.3f} [{n}]")

            log.info(f"json_str: {json.dumps(json_result)}")

            # Training
            stats.reset()
            for batch in tqdm.tqdm(loader, total=len(loader)):
                optimizer.zero_grad()

                d = batch2dict(batch)

                f, _, _ = model(d["x"])
                f_pos, _, _ = model(d["x_pos"])
                f_neg, _, _ = model(d["x_neg"])

                pos_loss = loss_func(f, f_pos)
                neg_loss = loss_func(f, f_neg)
                loss = _contrastive_loss(pos_loss, neg_loss, args.loss,
                                         args.temp)

                stats["train_loss"].feed(loss.detach().item())

                loss.backward()
                optimizer.step()

            log.info("\n" + stats.summary(i))

    torch.save(model, args.save_file)