Example #1
    def __init__(self, args, data_id, model_id, optim_id, train_loader,
                 eval_loader, model, optimizer, scheduler_iter,
                 scheduler_epoch):

        # Edit args
        if args.eval_every is None:
            args.eval_every = args.epochs
        if args.check_every is None:
            args.check_every = args.epochs
        if args.name is None:
            args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
        if args.project is None:
            args.project = '_'.join([data_id, model_id])

        # Move model
        model = model.to(args.device)
        if args.parallel == 'dp':
            model = DataParallelDistribution(model)

        # Init parent
        log_path = os.path.join(self.log_base, data_id, model_id, optim_id,
                                args.name)
        super(FlowExperiment, self).__init__(model=model,
                                             optimizer=optimizer,
                                             scheduler_iter=scheduler_iter,
                                             scheduler_epoch=scheduler_epoch,
                                             log_path=log_path,
                                             eval_every=args.eval_every,
                                             check_every=args.check_every)

        # Store args
        self.create_folders()
        self.save_args(args)
        self.args = args

        # Store IDs
        self.data_id = data_id
        self.model_id = model_id
        self.optim_id = optim_id

        # Store data loaders
        self.train_loader = train_loader
        self.eval_loader = eval_loader

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text("args",
                                 get_args_table(args_dict).get_html_string(),
                                 global_step=0)
        if args.log_wandb:
            wandb.init(config=args_dict,
                       project=args.project,
                       id=args.name,
                       dir=self.log_path)
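A minimal construction sketch follows (an illustrative addition, not taken from the original repository): the FlowExperiment import path, the toy model, the loaders, and the argument values are placeholder assumptions, and only the keyword layout mirrors the signature above.

import argparse

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical import path -- adjust to wherever FlowExperiment lives in your project.
from experiment.flow import FlowExperiment

# Stand-in components so the full call signature is visible.
model = nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data = TensorDataset(torch.randn(64, 2))
train_loader = DataLoader(data, batch_size=16)
eval_loader = DataLoader(data, batch_size=16)

args = argparse.Namespace(
    device='cuda' if torch.cuda.is_available() else 'cpu',
    parallel=None,        # set to 'dp' to wrap the model in DataParallelDistribution
    epochs=10,
    eval_every=None,      # falls back to args.epochs in __init__
    check_every=None,     # falls back to args.epochs in __init__
    name=None,            # falls back to a timestamp in __init__
    project=None,         # falls back to '<data_id>_<model_id>' in __init__
    log_tb=False,         # no TensorBoard writer in this sketch
    log_wandb=False,      # no wandb run in this sketch
)

experiment = FlowExperiment(args,
                            data_id='toy', model_id='flow', optim_id='adam',
                            train_loader=train_loader, eval_loader=eval_loader,
                            model=model, optimizer=optimizer,
                            scheduler_iter=None, scheduler_epoch=None)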
Example #2
    def save_args(self, args):

        # Save args
        with open(os.path.join(self.log_path, 'args.pickle'), "wb") as f:
            pickle.dump(args, f)

        # Save args table
        args_table = get_args_table(vars(args))
        with open(os.path.join(self.log_path, 'args_table.txt'), "w") as f:
            f.write(str(args_table))
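For reference, a short sketch of reading the pickled args back, e.g. when resuming or evaluating a finished run (an illustrative addition, not part of the original example; log_path is a placeholder for the run directory created above).

import os
import pickle

log_path = '/path/to/run'  # placeholder: the log_path used when the run was created
with open(os.path.join(log_path, 'args.pickle'), 'rb') as f:
    args = pickle.load(f)  # restores the exact argparse.Namespace that was saved
print(vars(args))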
Example #3
# Train
print('Training...')
writer = SummaryWriter(tb_writer_path)
for epoch in range(args.epochs):
    train_ppll = train(model, train_loader, epoch=epoch)
    writer.add_scalar('train_ppll', train_ppll, global_step=epoch + 1)
    if (epoch + 1) % args.valid_every == 0:
        valid_ppll = evaluate(model, valid_loader)
        writer.add_scalar('valid_ppll', valid_ppll, global_step=epoch + 1)

# Save log-likelihood
with open('results/{}_valid_loglik.txt'.format(run_name), 'w') as f:
    f.write(str(valid_ppll))

# Save args
args_table = get_args_table(vars(args))
with open('results/{}_args.txt'.format(run_name), 'w') as f:
    f.write(str(args_table))

# Save model
state_dict = model.state_dict()
torch.save(state_dict, 'models/{}.pt'.format(run_name))
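# Note (illustrative addition, not in the original script): the checkpoint saved
# above can be restored into the same architecture later with:
#   model.load_state_dict(torch.load('models/{}.pt'.format(run_name)))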

##########
## Test ##
##########

# Test
if args.iwbo_k is None:
    test_ppll = evaluate(model, test_loader)
else: