Example #1
import torch
from tqdm import tqdm

# Repo-local helpers (to_cuda, parse_val_logs, mean_log_list, save_model)
# are assumed to be importable from the surrounding project.
def train(args,
          epochs,
          trainloader,
          valloader,
          model,
          optimiser,
          loss_fn,
          logger=None,
          metric_list=None,
          cuda=True):
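    """Optimisation loop: for each epoch, run one pass over ``trainloader``
    (forward, backward, optimiser step) followed by a no-grad pass over
    ``valloader``, aggregate the validation logs, write them to the logger
    and checkpoint the model. Returns the aggregated logs of the last epoch.
    """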
    pb = tqdm(total=epochs, unit_scale=True, smoothing=0.1, ncols=150)
    update_frac = 1. / float(len(trainloader) + len(valloader))
    global_step = getattr(args, 'global_step', None)
    if global_step is None:
        global_step = 0
    loss, val_loss = torch.tensor(0), torch.tensor(0)
    mean_logs = {}

    for i in range(epochs):
        for t, data in enumerate(trainloader):
            optimiser.zero_grad()
            model.train()
            data = to_cuda(data) if cuda else data
            out = model.train_step(data, t, loss_fn)
            loss = out['loss']
            loss.backward()
            optimiser.step()
            pb.update(update_frac)
            pgs = [pg['lr'] for pg in optimiser.param_groups]
            pb.set_postfix_str(
                'ver:{}, loss:{:.3f}, val_loss:{:.3f}, lr:{}'.format(
                    logger.get_version(), loss.item(), val_loss.item(), pgs))
            global_step += 1

        log_list = []
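        # Validation pass: evaluate each batch without gradients and collect
        # the parsed logs for averaging at the end of the epoch.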
        with torch.no_grad():
            for t, data in enumerate(valloader):
                model.eval()
                data = to_cuda(data) if cuda else data
                out = model.val_step(data, t, loss_fn)
                val_loss = out['loss']
                logs = out['out']
                log_list.append(
                    parse_val_logs(t, args, model, data, logger, metric_list,
                                   logs, out['state'], global_step))
                pb.update(update_frac)
                pb.set_postfix_str(
                    'ver:{}, loss:{:.3f}, val_loss:{:.3f}'.format(
                        logger.get_version(), loss.item(), val_loss.item()))
                global_step += 1

        mean_logs = mean_log_list(log_list)
        if logger is not None:
            logger.write_dict(mean_logs, global_step)
        save_model(logger, model, args)
    return mean_logs
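The loop above assumes the model exposes train_step and val_step methods that return dictionaries, and that helpers such as to_cuda and parse_val_logs come from the surrounding project. Below is a minimal sketch of the model interface the loop relies on; the class name and layer sizes are illustrative assumptions, not part of the original code.

import torch.nn as nn


class ToyModel(nn.Module):
    """Illustrates the step interface that train() expects."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(16, 16)

    def train_step(self, data, t, loss_fn):
        x_hat = self.net(data)
        # train() only reads the 'loss' key and backpropagates through it.
        return {'loss': loss_fn(x_hat, data)}

    def val_step(self, data, t, loss_fn):
        x_hat = self.net(data)
        # The validation loop additionally expects the raw outputs and a
        # model-state entry, which it forwards to parse_val_logs().
        return {'loss': loss_fn(x_hat, data),
                'out': {'x_hat': x_hat},
                'state': None}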
Example #2
import gc
import os

import torch

# Registries and helpers such as datasets, models, dataset_meta, model_loader,
# set_to_loader, MetricAggregator, Logger, count_parameters, write_args,
# mean_log_list and train are assumed to come from the surrounding project.
def run(args):
    if args.evaluate or args.load_model:
        checkpoint_path = os.path.join(args.log_path, 'checkpoints')
        model_state, old_args = model_loader(checkpoint_path)
    if args.evaluate:
        old_args.data_path, old_args.log_path = args.data_path, args.log_path
        old_args.evaluate, old_args.visualise, old_args.metrics = args.evaluate, args.visualise, args.metrics
        args = old_args

    args.nc = dataset_meta[args.dataset]['nc']
    args.factors = dataset_meta[args.dataset]['factors']
    trainds, valds = datasets[args.dataset](args)
    trainloader, valloader = set_to_loader(trainds, valds, args)

    model = models[args.model](args)

    if args.evaluate or args.load_model:
        model.load_state_dict(model_state)
    model.cuda()

    if args.base_model_path is not None:
        model_state, _ = model_loader(args.base_model_path)
        model.load_vae_state(model_state)

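    # Prefer a three-way optimiser with separate learning rates for the VAE,
    # action and group parameters; fall back to a single optimiser over all
    # parameters when the model does not expose those groups.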
    try:
        if args.policy_learning_rate is None:
            args.policy_learning_rate = args.learning_rate
        optimiser = torch.optim.Adam([
            {'params': model.vae_params(), 'lr': args.learning_rate},
            {'params': model.action_params(), 'lr': args.policy_learning_rate},
            {'params': model.group_params(), 'lr': args.group_learning_rate},
        ])
    except Exception:
        print(
            'Failed to use vae-action-group optimiser setup. Falling back to .parameters() optimiser'
        )
        optimiser = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    paired = args.model in ['rgrvae', 'forward', 'dforward']
    loss_fn = lambda x_hat, x: (x_hat.sigmoid() - x).pow(2).sum() / x.shape[0]
    metric_list = MetricAggregator(valds.dataset, 1000, model,
                                   paired) if args.metrics else None

    version = None
    if args.log_path is not None and args.load_model:
        for a in args.log_path.split('/'):
            if 'version_' in a:
                version = a.split('_')[-1]

    logger = Logger('./logs/', version)
    param_count = count_parameters(model)
    logger.writer.add_text('parameters/number_params',
                           param_count.replace('\n', '\n\n'), 0)
    print(param_count)

    write_args(args, logger)
    if not args.evaluate:
        out = train(args, args.epochs, trainloader, valloader, model,
                    optimiser, loss_fn, logger, metric_list, True)
    else:
        out = {}

    if args.evaluate or args.end_metrics:
        log_list = MetricAggregator(trainds.dataset,
                                    valds.dataset,
                                    1000,
                                    model,
                                    paired,
                                    args.latents,
                                    ntrue_actions=args.latents,
                                    final=True)()
        mean_logs = mean_log_list([log_list])
        if logger is not None:
            logger.write_dict(mean_logs, model.global_step + 1)

    gc.collect()
    return out
Example #3
import gc
import os

import torch
import torch.nn.functional as F


def run(args):
    if args.evaluate or args.load_model:
        checkpoint_path = os.path.join(args.log_path, 'checkpoints')
        model_state, old_args = model_loader(checkpoint_path)
    if args.evaluate:
        old_args.data_path, old_args.log_path = args.data_path, args.log_path
        old_args.evaluate, old_args.visualise, old_args.metrics = args.evaluate, args.visualise, args.metrics
        old_args.eval_dataset, old_args.eval_data_path = args.eval_dataset, args.eval_data_path
        old_args.split = args.split
        args = old_args

    args.nc = dataset_meta[args.dataset]['nc']
    args.factors = dataset_meta[args.dataset]['factors']
    trainds, valds = datasets[args.dataset](args)
    trainloader, valloader = set_to_loader(trainds, valds, args)

    model = models[args.model](args)

    if args.evaluate or args.load_model:
        model.load_state_dict(model_state)
    model.cuda()

    if args.base_model_path is not None:
        model_state, _ = model_loader(args.base_model_path)
        model.load_vae_state(model_state)

    # if args.model == 'lie_group_rl' and not args.supervised_train:
    # print('Using separate optimisers for each sub module.')
    # if args.policy_learning_rate is None:
    # args.policy_learning_rate = args.learning_rate
    # optimiser_ls = [torch.optim.Adam([{'params': model.vae_params(), 'lr': args.learning_rate * 1},
    # {'params': model.action_params(), 'lr': args.policy_learning_rate * 1}]),
    # torch.optim.Adam(model.group_params(), lr=args.group_learning_rate)]
    # else:
    try:
        if args.policy_learning_rate is None:
            args.policy_learning_rate = args.learning_rate
        optimiser = torch.optim.Adam([
            {'params': model.vae_params(), 'lr': args.learning_rate},
            {'params': model.action_params(), 'lr': args.policy_learning_rate},
            {'params': model.group_params(), 'lr': args.group_learning_rate},
        ])
    except Exception:
        print(
            'Failed to use vae-action-group optimiser setup. Falling back to .parameters() optimiser'
        )
        optimiser = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    paired = args.model in ['rgrvae', 'forward', 'dforward']
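    # Either reconstruction loss treats x_hat as pre-sigmoid logits and is
    # normalised by the batch size.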
    if args.recons_loss_type == 'l2':
        loss_fn = lambda x_hat, x: (x_hat.sigmoid() - x).pow(2).sum() / x.shape[0]
    else:
        loss_fn = lambda x_hat, x: F.binary_cross_entropy_with_logits(
            x_hat.view(x_hat.size(0), -1),
            x.view(x.size(0), -1),
            reduction='sum') / x.shape[0]
    metric_list = MetricAggregator(trainds,
                                   trainds,
                                   1000,
                                   model,
                                   paired,
                                   args.latents,
                                   ntrue_actions=args.latents,
                                   final=True) if args.metrics else None

    version = None
    if args.log_path is not None and args.load_model:
        for a in args.log_path.split('/'):
            if 'version_' in a:
                version = a.split('_')[-1]

    # logger = Logger('./logs/', version)
    logger = Logger(args.log_path, version)
    param_count = count_parameters(model)
    logger.writer.add_text('parameters/number_params',
                           param_count.replace('\n', '\n\n'), 0)
    print(param_count)

    write_args(args, logger)
    if not args.evaluate:
        # if args.model == 'lie_group_rl' and not args.supervised_train:
        # out = train_lie(args, args.epochs, trainloader, valloader, model, optimiser_ls, loss_fn, logger, metric_list, True)
        # else:
        out = train(args, args.epochs, trainloader, valloader, model,
                    optimiser, loss_fn, logger, metric_list, True)
    else:
        out = {}

    if args.evaluate or args.end_metrics:
        if args.eval_dataset and args.eval_data_path:
            del trainds
            del trainloader
            del valloader
            args.dataset = args.eval_dataset
            args.data_path = args.eval_data_path
            trainds, _ = datasets[args.eval_dataset](args)
        log_list = MetricAggregator(trainds,
                                    trainds,
                                    1000,
                                    model,
                                    paired,
                                    args.latents,
                                    ntrue_actions=args.latents,
                                    final=True)()
        mean_logs = mean_log_list([log_list])
        if logger is not None:
            logger.write_dict(mean_logs, model.global_step + 1)

    gc.collect()
    return out
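For reference, run() reads its configuration from an argparse-style namespace. The sketch below lists the fields the code above touches; the field names are taken from the code, but every value shown is an illustrative assumption rather than a documented default, and the dataset/model keys must match entries in the project's datasets and models registries.

from argparse import Namespace

args = Namespace(
    # model / data selection (values are placeholders)
    model='rgrvae', dataset='my_dataset', data_path='./data',
    eval_dataset=None, eval_data_path=None, split=0.9,
    # optimisation
    epochs=50, learning_rate=1e-4, policy_learning_rate=None,
    group_learning_rate=1e-4, recons_loss_type='bce', latents=10,
    # checkpointing / logging / evaluation
    log_path='./logs', load_model=False, evaluate=False,
    base_model_path=None, visualise=False, metrics=False, end_metrics=True,
)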