def train(args, epochs, trainloader, valloader, model, optimiser, loss_fn,
          logger=None, metric_list=None, cuda=True):
    pb = tqdm(total=epochs, unit_scale=True, smoothing=0.1, ncols=150)
    # Each batch (train or val) advances the bar by this fraction of an epoch.
    update_frac = 1. / float(len(trainloader) + len(valloader))
    # Resume the global step counter if one was stored on args.
    global_step = 0 if not hasattr(args, 'global_step') or args.global_step is None \
        else args.global_step
    ver = logger.get_version() if logger is not None else 'n/a'
    loss, val_loss = torch.tensor(0), torch.tensor(0)
    mean_logs = {}
    for i in range(epochs):
        # Training pass.
        for t, data in enumerate(trainloader):
            optimiser.zero_grad()
            model.train()
            data = to_cuda(data) if cuda else data
            out = model.train_step(data, t, loss_fn)
            loss = out['loss']
            loss.backward()
            optimiser.step()

            pb.update(update_frac)
            pgs = [pg['lr'] for pg in optimiser.param_groups]
            pb.set_postfix_str('ver:{}, loss:{:.3f}, val_loss:{:.3f}, lr:{}'.format(
                ver, loss.item(), val_loss.item(), pgs))
            global_step += 1

        # Validation pass (no gradients).
        log_list = []
        with torch.no_grad():
            for t, data in enumerate(valloader):
                model.eval()
                # Rebind here: to_cuda returns moved tensors, it does not
                # move them in place (the original discarded its result).
                data = to_cuda(data) if cuda else data
                out = model.val_step(data, t, loss_fn)
                val_loss = out['loss']
                logs = out['out']
                log_list.append(parse_val_logs(t, args, model, data, logger,
                                               metric_list, logs, out['state'],
                                               global_step))
                pb.update(update_frac)
                pb.set_postfix_str('ver:{}, loss:{:.3f}, val_loss:{:.3f}'.format(
                    ver, loss.item(), val_loss.item()))
                global_step += 1

        mean_logs = mean_log_list(log_list)
        if logger is not None:
            logger.write_dict(mean_logs, global_step)
        # Checkpoint once per epoch.
        save_model(logger, model, args)
    return mean_logs
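# Illustrative only (not in the original file): train() assumes the model
# exposes train_step/val_step hooks that return dicts with a 'loss' entry,
# plus 'out' and 'state' consumed by parse_val_logs during validation.
# A minimal autoencoder-style model satisfying that contract, assuming the
# loaders yield plain input tensors:
class _ExampleStepModel(torch.nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.net = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return self.net(x)

    def train_step(self, data, t, loss_fn):
        # Reconstruct the input and report the loss train() will backprop.
        x_hat = self(data)
        return {'loss': loss_fn(x_hat, data)}

    def val_step(self, data, t, loss_fn):
        # Same as train_step, plus the logging payload train() expects.
        x_hat = self(data)
        return {'loss': loss_fn(x_hat, data), 'out': {}, 'state': None}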
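# Illustrative only (not in the original file): run() below picks between an
# l2 reconstruction loss (sigmoid applied to the logits before the squared
# error) and a BCE-with-logits loss (raw logits); both sum over all elements
# and divide by the batch size. A quick side-by-side on a dummy batch:
def _loss_demo():
    x_hat = torch.randn(8, 3, 64, 64)  # raw decoder logits
    x = torch.rand(8, 3, 64, 64)       # targets in [0, 1]
    l2 = (x_hat.sigmoid() - x).pow(2).sum() / x.shape[0]
    bce = F.binary_cross_entropy_with_logits(
        x_hat.view(8, -1), x.view(8, -1), reduction='sum') / x.shape[0]
    return l2, bce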
def run(args):
    # When evaluating or resuming, restore the checkpoint and reuse its stored
    # args, overriding only the fields that must come from this invocation.
    if args.evaluate or args.load_model:
        checkpoint_path = os.path.join(args.log_path, 'checkpoints')
        model_state, old_args = model_loader(checkpoint_path)
        if args.evaluate:
            old_args.data_path, old_args.log_path = args.data_path, args.log_path
            old_args.evaluate, old_args.visualise, old_args.metrics = \
                args.evaluate, args.visualise, args.metrics
            old_args.eval_dataset, old_args.eval_data_path = \
                args.eval_dataset, args.eval_data_path
            old_args.split = args.split
            args = old_args

    args.nc = dataset_meta[args.dataset]['nc']
    args.factors = dataset_meta[args.dataset]['factors']
    trainds, valds = datasets[args.dataset](args)
    trainloader, valloader = set_to_loader(trainds, valds, args)

    model = models[args.model](args)
    if args.evaluate or args.load_model:
        model.load_state_dict(model_state)
    model.cuda()

    if args.base_model_path is not None:
        model_state, _ = model_loader(args.base_model_path)
        model.load_vae_state(model_state)

    # Prefer per-module learning rates (VAE / action / group); fall back to a
    # single optimiser over all parameters for models without those groupings.
    try:
        if args.policy_learning_rate is None:
            args.policy_learning_rate = args.learning_rate
        optimiser = torch.optim.Adam([
            {'params': model.vae_params(), 'lr': args.learning_rate},
            {'params': model.action_params(), 'lr': args.policy_learning_rate},
            {'params': model.group_params(), 'lr': args.group_learning_rate},
        ])
    except AttributeError:
        print('Failed to use vae-action-group optimiser setup. '
              'Falling back to .parameters() optimiser')
        optimiser = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    # Models trained on paired samples.
    paired = args.model in ['rgrvae', 'forward', 'dforward']

    if args.recons_loss_type == 'l2':
        loss_fn = lambda x_hat, x: (x_hat.sigmoid() - x).pow(2).sum() / x.shape[0]
    else:
        loss_fn = lambda x_hat, x: F.binary_cross_entropy_with_logits(
            x_hat.view(x_hat.size(0), -1), x.view(x.size(0), -1),
            reduction='sum') / x.shape[0]

    metric_list = MetricAggregator(
        trainds, trainds, 1000, model, paired, args.latents,
        ntrue_actions=args.latents, final=True) if args.metrics else None

    # Recover the version suffix from the log path when resuming into an
    # existing run directory.
    version = None
    if args.log_path is not None and args.load_model:
        for a in args.log_path.split('/'):
            if 'version_' in a:
                version = a.split('_')[-1]
    logger = Logger(args.log_path, version)

    param_count = count_parameters(model)
    logger.writer.add_text('parameters/number_params',
                           param_count.replace('\n', '\n\n'), 0)
    print(param_count)
    write_args(args, logger)

    if not args.evaluate:
        out = train(args, args.epochs, trainloader, valloader, model,
                    optimiser, loss_fn, logger, metric_list, True)
    else:
        out = {}

    if args.evaluate or args.end_metrics:
        # Optionally swap in a dedicated evaluation dataset for final metrics.
        if args.eval_dataset and args.eval_data_path:
            del trainds, trainloader, valloader
            args.dataset = args.eval_dataset
            args.data_path = args.eval_data_path
            trainds, _ = datasets[args.eval_dataset](args)
        log_list = MetricAggregator(
            trainds, trainds, 1000, model, paired, args.latents,
            ntrue_actions=args.latents, final=True)()
        mean_logs = mean_log_list([log_list])
        if logger is not None:
            logger.write_dict(mean_logs, model.global_step + 1)
    gc.collect()
    return out
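# Illustrative entry point (not in the original file): the flag names below
# are assumptions inferred from the attributes run() reads off `args`; the
# project's real CLI likely defines more options (log_path, visualise,
# metrics, eval_dataset, eval_data_path, split, latents, end_metrics, ...).
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--policy_learning_rate', type=float, default=None)
    parser.add_argument('--group_learning_rate', type=float, default=1e-4)
    parser.add_argument('--recons_loss_type', type=str, default='bce',
                        choices=['l2', 'bce'])
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--load_model', action='store_true')
    run(parser.parse_args())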