Example #1
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# HypothesisNet and log_this are assumed to be provided by the surrounding project.


def train(args):

    net = HypothesisNet(args)

    if not args.no_log:
        log = log_this(net.args, 'logs/sim', args.name, checkpoints=False)

    simulator = net.simulator
    if args.no_reservoir:
        layer = net.W_ro
    else:
        layer = net.reservoir

    batch_size = 10

    criterion = nn.MSELoss()
    train_params = simulator.parameters()
    optimizer = optim.Adam(train_params, lr=args.lr)

    for i in range(args.iters):

        if not args.no_reservoir:
            layer.reset('random')
        optimizer.zero_grad()

        prop = torch.Tensor(
            np.random.normal(0, 5, size=(batch_size, net.args.D)))
        state = torch.Tensor(
            np.random.normal(0, 10, size=(batch_size, net.args.L)))
        sim_out = simulator(state, prop)

        # run the reservoir for forward_steps steps, so we predict that many steps into the future
        outs = []
        for j in range(args.forward_steps):
            outs.append(layer(prop))

        actions = sum(outs)

        # get state output
        layer_out = actions + state

        # validation check: rolling the batch pairs each state with another sample's actions,
        # so this loss should stay high as long as the outputs are sample-specific
        layer_out_val = actions.roll(1, 0) + state

        # MSE loss against the simulator's prediction
        loss = criterion(layer_out, sim_out)
        loss_val = criterion(layer_out_val, sim_out)

        loss.backward()
        optimizer.step()

        if i % 50 == 0 and i != 0:
            print(f'iteration: {i} | loss {loss.item():.4f} | loss_val {loss_val.item():.4f}')

    if not args.no_log:
        save_model_path = os.path.join(log.run_dir, f'model_{log.run_id}.pth')
        save_sim_path = os.path.join(log.run_dir, f'sim_{log.run_id}.pth')
        torch.save(net.state_dict(), save_model_path)
        torch.save(simulator.state_dict(), save_sim_path)
        print(f'saved model to {save_model_path}, sim to {save_sim_path}')
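
A minimal way to drive Example #1 might look like the following. This is a sketch, not part of the original code: HypothesisNet is assumed to accept an argparse-style namespace, the attribute names are exactly the ones train() reads above, and the values are placeholders rather than the project's defaults (HypothesisNet itself likely expects further attributes such as D and L, which it exposes back as net.args).

from argparse import Namespace

args = Namespace(
    name='sim_run',        # run label passed to log_this
    no_log=True,           # skip logging and checkpoint saving for a dry run
    no_reservoir=False,    # use net.reservoir (rather than net.W_ro) as the action layer
    lr=1e-3,               # Adam learning rate
    iters=2000,            # number of training iterations
    forward_steps=10,      # reservoir steps summed into each prediction
)
train(args)
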
Example #2
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# HypothesisNet and log_this are assumed to be provided by the surrounding project.


def train(args):

    net = HypothesisNet(args)

    if not args.no_log:
        log = log_this(net.args, 'logs/hyp', args.name, checkpoints=False)

    simulator = net.simulator
    hypothesizer = net.hypothesizer
    # the loop below uses an action layer; as in Example #1, this is assumed to be the reservoir
    layer = net.reservoir

    batch_size = 10

    criterion = nn.MSELoss()
    train_params = hypothesizer.parameters()
    optimizer = optim.Adam(train_params, lr=1e-3)

    for i in range(args.iters):

        optimizer.zero_grad()

        state = torch.Tensor(
            np.random.normal(0, 10, size=(batch_size, net.args.L)))
        task = torch.Tensor(
            np.random.normal(0, 10, size=(batch_size, net.args.L)))

        prop = hypothesizer(state, task)
        sim_out = simulator(state, prop)

        # run reservoir 10 steps, so predict 10 steps in future
        outs = []
        for j in range(10):
            outs.append(layer(prop))

        actions = sum(outs)

        # get state output
        layer_out = actions + state

        # validation check: rolling the batch pairs each state with another sample's actions,
        # so this loss should stay high as long as the outputs are sample-specific
        layer_out_val = actions.roll(1, 0) + state

        # loss on the euclidean distance between the predicted state and the simulator's prediction
        diff = torch.norm(layer_out - sim_out, dim=1)
        loss = criterion(diff, torch.zeros_like(diff))

        diff_val = torch.norm(layer_out_val - sim_out, dim=1)
        loss_val = criterion(diff_val, torch.zeros_like(diff_val))

        loss.backward()
        optimizer.step()

        if i % 50 == 0 and i != 0:
            print(f'iteration: {i} | loss {loss.item():.4f} | loss_val {loss_val.item():.4f}')

    if not args.no_log:
        save_model_path = os.path.join(log.run_dir, f'model_{log.run_id}.pth')
        torch.save(net.state_dict(), save_model_path)
        print(f'saved model to {save_model_path}')
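
Examples #1 and #2 are complementary: the first optimizes simulator.parameters(), while the second optimizes only hypothesizer.parameters() and uses the simulator's prediction as the target. A plausible driver, with hypothetical module names, runs them in that order; note that since Example #2 builds a fresh HypothesisNet, the sim_<run_id>.pth checkpoint saved by Example #1 would presumably need to be loaded into net.simulator in between for the ordering to matter.

# hypothetical module names; adjust to wherever these two train functions live
from train_sim import train as train_simulator
from train_hyp import train as train_hypothesizer

train_simulator(args)     # Example #1: fit the simulator to the reservoir's state update
train_hypothesizer(args)  # Example #2: fit the hypothesizer against the trained simulator
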
Example #3
import csv
import logging
import os
import pickle

import numpy as np
import torch

# Project helpers such as BasicNetwork, StateNet, HypothesisNet, get_criterion,
# get_optimizer, get_potential, load_rb, get_x_y, and goals_loss are assumed to
# be importable from the surrounding codebase.


class Trainer:
    def __init__(self, args):

        self.args = args

        if self.args.net == 'basic':
            self.net = BasicNetwork(self.args)
        elif self.args.net == 'state':
            self.net = StateNet(self.args)
        elif self.args.net == 'hypothesis':
            self.net = HypothesisNet(self.args)
        else:
            raise ValueError(f'unrecognized net type: {self.args.net}')

        # picks which parameters to train and which not to train
        self.n_params = {}
        self.train_params = []
        self.not_train_params = []
        logging.info('Training the following parameters:')
        for k, v in self.net.named_parameters():
            # k is name, v is weight
            found = False
            # filtering just for the parts that will be trained
            for part in self.args.train_parts:
                if part in k:
                    logging.info(f'  {k}')
                    self.n_params[k] = (v.shape, v.numel())
                    self.train_params.append(v)
                    found = True
                    break
            if not found:
                self.not_train_params.append(k)
        logging.info('Not training:')
        for k in self.not_train_params:
            logging.info(f'  {k}')

        self.criterion = get_criterion(self.args)
        self.optimizer = get_optimizer(self.args, self.train_params)
        self.dset = load_rb(self.args.dataset)
        self.potential = get_potential(self.args)

        # if using separate training and test sets, separate them out
        if not self.args.same_test:
            np.random.shuffle(self.dset)
            cutoff = round(.9 * len(self.dset))
            self.train_set = self.dset[:cutoff]
            self.test_set = self.dset[cutoff:]
            logging.info(
                f'Using separate training ({cutoff}) and test ({len(self.dset) - cutoff}) sets.'
            )
        else:
            self.train_set = self.dset
            self.test_set = self.dset

        self.log_interval = self.args.log_interval
        if not self.args.no_log:
            self.log = self.args.log
            self.run_id = self.args.log.run_id
            self.vis_samples = []
            self.csv_path = open(
                os.path.join(self.log.run_dir, f'losses_{self.run_id}.csv'),
                'a')
            self.writer = csv.writer(self.csv_path,
                                     delimiter=',',
                                     quotechar='|',
                                     quoting=csv.QUOTE_MINIMAL)
            self.writer.writerow(['ix', 'avg_loss'])
            self.plot_checkpoint_path = os.path.join(
                self.log.run_dir, f'checkpoints_{self.run_id}.pkl')
            self.save_model_path = os.path.join(self.log.run_dir,
                                                f'model_{self.run_id}.pth')

    def log_model(self, ix=0):
        # saving every checkpoint takes too much space, so we keep only the latest model unless checkpoint logging is explicitly enabled
        if self.args.log_checkpoint_models:
            self.save_model_path = os.path.join(self.log.checkpoint_dir,
                                                f'model_{ix}.pth')
        elif os.path.exists(self.save_model_path):
            os.remove(self.save_model_path)
        torch.save(self.net.state_dict(), self.save_model_path)

    def log_checkpoint(self, ix, x, y, z, total_loss, avg_loss):
        self.writer.writerow([ix, avg_loss])
        self.csv_path.flush()

        self.log_model(ix)

        # we can save individual samples at each checkpoint, that's not too bad space-wise
        self.vis_samples.append([ix, x, y, z, total_loss, avg_loss])
        if os.path.exists(self.plot_checkpoint_path):
            os.remove(self.plot_checkpoint_path)
        with open(self.plot_checkpoint_path, 'wb') as f:
            pickle.dump(self.vis_samples, f)

    def train_iteration(self, x, y):
        self.net.reset()
        self.optimizer.zero_grad()

        outs = []
        total_loss = torch.tensor(0.)

        # ins is actual input into the network
        # targets is desired output
        # outs is output of network
        if self.args.dset_type == 'goals':
            ins = []
            l_other = {'kl': 0, 'lconf': 0, 'lsim': 0, 'lfprop': 0, 'lp': 0}
            targets = x
            cur_idx = torch.zeros(x.shape[0], dtype=torch.long)
            for j in range(self.args.goals_timesteps):
                net_out, step_loss, cur_idx, extras = self.run_iter_goal(
                    x, cur_idx)
                # what we need to record for logging
                ins.append(extras['in'])
                outs.append(net_out[-1].detach().numpy())
                total_loss += step_loss

                if 'kl' in extras and extras['kl'] is not None:
                    l_other['kl'] += extras['kl']
                if 'lconf' in extras and extras['lconf'] is not None:
                    l_other['lconf'] += extras['lconf']
                if 'lsim' in extras and extras['lsim'] is not None:
                    l_other['lsim'] += extras['lsim']
                if 'lp' in extras and extras['lp'] is not None:
                    l_other['lp'] += extras['lp']
                # if 'lfprop' in extras and extras['lfprop'] is not None:
                #     l_other['lfprop'] += extras['lfprop']

            ins = torch.cat(ins)

        else:
            ins = x
            targets = y
            for j in range(x.shape[1]):
                net_out, step_loss, extras = self.run_iter_traj(
                    x[:, j], y[:, j])
                if np.isnan(step_loss.item()):
                    return -1, (net_out, extras)
                total_loss += step_loss
                outs.append(net_out[-1].item())

        total_loss.backward()
        self.optimizer.step()

        etc = {
            'ins': ins,
            'targets': targets,
            'outs': outs,
            'prop': extras['prop'],
        }
        if self.args.dset_type == 'goals':
            # l_other and cur_idx only exist on the goals branch
            etc.update(l_other)
            etc['indices'] = cur_idx

        return total_loss, etc

    # runs an iteration where we want to match a certain trajectory
    def run_iter_traj(self, x, y):
        net_in = x.reshape(-1, self.args.L)
        net_out, extras = self.net(net_in, extras=True)
        net_target = y.reshape(-1, self.args.Z)
        step_loss = self.criterion(net_out, net_target)

        return net_out, step_loss, extras

    # runs an iteration where we want to hit a certain goal (dynamic input)
    def run_iter_goal(self, x, indices):
        x_goal = x[torch.arange(x.shape[0]), indices, :]

        net_in = x_goal.reshape(-1, self.args.L)
        net_out, extras = self.net(net_in, extras=True)
        # the target is actually the input
        step_loss, new_indices = goals_loss(
            net_out, x, indices, threshold=self.args.goals_threshold)
        # extras like 'kl', 'lconf', and 'lsim' will be None if we just started, or if we're not doing variational training

        # non-goals related losses
        # if net_out.shape[0] != 1:
        #     pdb.set_trace()
        extras['lp'] = self.potential(net_out).sum()
        step_loss += extras['lp']
        if 'kl' in extras and extras['kl'] is not None:
            step_loss += extras['kl']
        if 'lconf' in extras and extras['lconf'] is not None:
            step_loss += extras['lconf']
        if 'lsim' in extras and extras['lsim'] is not None:
            step_loss += extras['lsim']
        # if 'lfprop' in extras and extras['lfprop'] is not None:
        #     step_loss += extras['lfprop']

        extras.update({'in': net_in})

        return net_out, step_loss, new_indices, extras

    def test(self, n=0):
        if n != 0:
            assert n <= len(self.test_set)
            batch_idxs = np.random.choice(len(self.test_set), n, replace=False)
            batch = [self.test_set[i] for i in batch_idxs]
        else:
            batch = self.test_set

        x, y = get_x_y(batch, self.args.dataset)

        with torch.no_grad():
            self.net.reset()
            total_loss = torch.tensor(0.)

            if self.args.dset_type == 'goals':
                cur_idx = torch.zeros(x.shape[0], dtype=torch.long)
                for j in range(self.args.goals_timesteps):
                    _, step_loss, cur_idx, _ = self.run_iter_goal(x, cur_idx)
                    total_loss += step_loss

            else:
                for j in range(x.shape[1]):
                    _, step_loss, _ = self.run_iter_traj(x[:, j], y[:, j])
                    total_loss += step_loss

        etc = {}
        if self.args.dset_type == 'goals':
            etc['indices'] = cur_idx

        return total_loss.item() / len(batch), etc

    def train(self, ix_callback=None):
        ix = 0

        its_p_epoch = len(self.train_set) // self.args.batch_size
        logging.info(
            f'Training set size {len(self.train_set)} | batch size {self.args.batch_size} --> {its_p_epoch} iterations / epoch'
        )

        # for convergence testing
        max_abs_grads = []
        running_min_error = float('inf')
        running_no_min = 0

        running_loss = 0.0
        # running_mag = 0.0
        ending = False
        for e in range(self.args.n_epochs):
            np.random.shuffle(self.train_set)
            epoch_idx = 0
            while epoch_idx < its_p_epoch:
                epoch_idx += 1
                batch = self.train_set[(epoch_idx - 1) *
                                       self.args.batch_size:epoch_idx *
                                       self.args.batch_size]
                if len(batch) < self.args.batch_size:
                    break
                ix += 1

                x, y = get_x_y(batch, self.args.dataset)
                loss, etc = self.train_iteration(x, y)

                if ix_callback is not None:
                    ix_callback(loss, etc)

                if loss == -1:
                    logging.info(f'iteration {ix}: is nan. ending')
                    ending = True
                    break

                running_loss += loss.item()
                # mag = max([torch.max(torch.abs(p.grad)) for p in self.train_params])
                # running_mag += mag

                if ix % self.log_interval == 0:
                    outs = etc['outs']

                    z = np.stack(outs).squeeze()
                    # average per-sample loss over the last log_interval iterations
                    avg_loss = running_loss / self.args.batch_size / self.log_interval
                    test_loss, test_etc = self.test(n=30)
                    # avg_max_grad = running_mag / self.log_interval
                    log_arr = [
                        f'iteration {ix}',
                        f'train loss {avg_loss:.3f}',
                        # f'max abs grad {avg_max_grad:.3f}',
                        f'test loss {test_loss:.3f}'
                    ]
                    # calculating average index reached for goals task
                    if self.args.dset_type == 'goals':
                        avg_index = test_etc['indices'].float().mean().item()
                        log_arr.append(f'avg index {avg_index:.3f}')
                    if self.args.net == 'hypothesis':
                        ha = self.net.log_h_yes.get_input()
                        sa = self.net.log_s_yes.get_input()
                        conf = self.net.log_conf.get_input()
                        lconf, lsim, kl, lp = etc['lconf'], etc['lsim'], etc[
                            'kl'], etc['lp']
                        log_arr.append(f'hyp_app {ha:.3f}')
                        log_arr.append(f'sim_app {sa:.3f}')
                        log_arr.append(f'conf {conf:.3f}')
                        log_arr.append(f'lconf {lconf:.3f}')
                        log_arr.append(f'lsim {lsim:.3f}')
                        log_arr.append(f'lp {lp:.3f}')
                        # log_arr.append(f'kl {kl:.3f}')
                    log_str = '\t| '.join(log_arr)
                    logging.info(log_str)

                    if not self.args.no_log:
                        self.log_checkpoint(ix, etc['ins'].numpy(),
                                            etc['targets'].numpy(), z,
                                            running_loss, avg_loss)
                    running_loss = 0.0
                    # running_mag = 0.0

                    # convergence based on no avg loss decrease after patience samples
                    if self.args.conv_type == 'patience':
                        if test_loss < running_min_error:
                            running_no_min = 0
                            running_min_error = test_loss
                        else:
                            running_no_min += self.log_interval
                        if running_no_min > self.args.patience:
                            logging.info(
                                f'iteration {ix}: no min for {self.args.patience} samples. ending'
                            )
                            ending = True
                    # elif self.args.conv_type == 'grad':
                    #     if avg_max_grad < self.args.grad_threshold:
                    #         logging.info(f'iteration {ix}: max absolute grad < {args.grad_threshold}. ending')
                    #         ending = True
                if ending:
                    break
            logging.info(f'Finished dataset epoch {e+1}')
            if ending:
                break

        if not self.args.no_log:
            # for later visualization of outputs over timesteps
            with open(
                    os.path.join(self.log.run_dir,
                                 f'checkpoints_{self.run_id}.pkl'), 'wb') as f:
                pickle.dump(self.vis_samples, f)

            self.csv_path.close()

        final_loss, etc = self.test()
        logging.info(
            f'END | iterations: {(ix // self.log_interval) * self.log_interval} | test loss: {final_loss}'
        )
        return final_loss, ix
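
A minimal driver sketch for the Trainer class. The project helpers (get_criterion, get_optimizer, get_potential, load_rb, get_x_y) are assumed to be importable, args is the same argparse-style namespace the rest of the repo builds, only the attributes read directly by the code above are listed, and the values are illustrative; the factory helpers will likely read further attributes (loss type, learning rate, and so on) that are not shown here.

import logging
from argparse import Namespace

logging.basicConfig(level=logging.INFO)

args = Namespace(
    net='basic',               # selects BasicNetwork in __init__
    train_parts=[''],          # '' is a substring of every name, so every parameter is trained
    dataset='data/train.pkl',  # hypothetical path handed to load_rb
    same_test=False,           # hold out 10% of the dataset as a test set
    batch_size=5,
    n_epochs=10,
    log_interval=50,
    conv_type='patience',
    patience=2000,
    dset_type='traj',          # anything other than 'goals' takes the trajectory branch
    no_log=True,               # skip csv/checkpoint/model logging
    L=2, Z=2,                  # input/output dims used to reshape x and y
)

trainer = Trainer(args)
final_loss, n_iters = trainer.train()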