Example #1
import os
import time

import tensorflow as tf
import torch
import torch.optim as optim

# Project-specific names (Vocab, Batcher, Model, ScheduledOptim, config,
# use_cuda, get_input_from_batch, get_output_from_batch,
# calc_running_avg_loss) are assumed to be importable from the
# surrounding codebase.


class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(self.vocab,
                               config.train_data_path,
                               config.batch_size,
                               single_pass=False,
                               mode='train')
        time.sleep(10)  # let the batcher's background threads fill the queue

        train_dir = os.path.join(config.log_root,
                                 'train_%d' % (int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'models')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iteration):
        model_state_dict = self.model.state_dict()

        state = {
            'iter': iteration,
            'current_loss': running_avg_loss,
            'optimizer': self.optimizer._optimizer.state_dict(),
            'model': model_state_dict
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_%d_%d' % (iteration, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_path):

        device = torch.device('cuda' if use_cuda else 'cpu')

        # source and target sides share the same vocabulary size
        self.model = Model(config.vocab_size,
                           config.vocab_size,
                           config.max_enc_steps,
                           config.max_dec_steps,
                           d_k=config.d_k,
                           d_v=config.d_v,
                           d_model=config.d_model,
                           d_word_vec=config.emb_dim,
                           d_inner=config.d_inner_hid,
                           n_layers=config.n_layers,
                           n_head=config.n_head,
                           dropout=config.dropout).to(device)

        self.optimizer = ScheduledOptim(
            optim.Adam(filter(lambda x: x.requires_grad,
                              self.model.parameters()),
                       betas=(0.9, 0.98),
                       eps=1e-09), config.d_model, config.n_warmup_steps)

        params = list(self.model.encoder.parameters()) + list(
            self.model.decoder.parameters())
        total_params = sum(param.nelement() for param in params)
        print('The Number of params of model: %.3f million' %
              (total_params / 1e6))

        start_iter, start_loss = 0, 0

        if model_path is not None:
            state = torch.load(model_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            if not config.is_coverage:
                self.optimizer._optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    # move restored optimizer state tensors onto the GPU
                    for opt_state in self.optimizer._optimizer.state.values():
                        for k, v in opt_state.items():
                            if torch.is_tensor(v):
                                opt_state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, batch):
        enc_batch, enc_lens, enc_pos, enc_padding_mask, enc_batch_extend_vocab, \
        extra_zeros, c_t, coverage = get_input_from_batch(batch, use_cuda, transformer=True)
        dec_batch, dec_lens, dec_pos, dec_padding_mask, max_dec_len, tgt_batch = \
            get_output_from_batch(batch, use_cuda, transformer=True)

        self.optimizer.zero_grad()

        pred = self.model(enc_batch, enc_pos, dec_batch, dec_pos)
        # negative log-likelihood of the gold tokens, masked over padding;
        # squeeze only the gathered dim so a batch of size 1 keeps its shape
        gold_probs = torch.gather(pred, -1, tgt_batch.unsqueeze(-1)).squeeze(-1)
        batch_loss = -torch.log(gold_probs + config.eps)
        batch_loss = batch_loss * dec_padding_mask

        sum_losses = torch.sum(batch_loss, 1)
        batch_avg_loss = sum_losses / dec_lens
        loss = torch.mean(batch_avg_loss)

        loss.backward()

        # update parameters
        self.optimizer.step_and_update_lr()

        return loss.item(), 0.  # the transformer variant has no coverage loss

    def run(self, n_iters, model_path=None):
        iteration, running_avg_loss = self.setup_train(model_path)
        start = time.time()
        interval = 100

        while iteration < n_iters:
            batch = self.batcher.next_batch()
            loss, cove_loss = self.train_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer,
                                                     iteration)
            iteration += 1

            if iteration % interval == 0:
                self.summary_writer.flush()
                print('step: %d, second: %.2f , loss: %f, cover_loss: %f' %
                      (iteration, time.time() - start, loss, cove_loss))
                start = time.time()
            if iteration % 5000 == 0:
                self.save_model(running_avg_loss, iteration)
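
All three examples wrap their optimizer in a ScheduledOptim but only show its call surface, and the constructor arguments differ per project. Below is a minimal sketch matching Example #1's interface (d_model and n_warmup_steps in the constructor; zero_grad and step_and_update_lr methods; the wrapped optimizer exposed as _optimizer), implementing the Noam schedule from "Attention Is All You Need" (lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)); the projects' actual implementations may differ.

class ScheduledOptim:
    """Wrap an optimizer with Noam warmup-then-decay learning rates."""

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.d_model = d_model
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0

    def zero_grad(self):
        self._optimizer.zero_grad()

    def step_and_update_lr(self):
        # advance the step count, recompute the learning rate, then step
        self.n_current_steps += 1
        lr = (self.d_model ** -0.5) * min(
            self.n_current_steps ** -0.5,
            self.n_current_steps * self.n_warmup_steps ** -1.5)
        for group in self._optimizer.param_groups:
            group['lr'] = lr
        self._optimizer.step()
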
Example #2
import torch
from torch.optim import Adam
from tqdm import tqdm

# Project-specific names (LocalEnergyCE, ScheduledOptim) are assumed to be
# importable from the surrounding codebase.


class LocalGenTrainer:
    def __init__(self, writer, model, device, args):
        self.model = model
        self.energy_fn = LocalEnergyCE(model, args)

        self.device = device

        # Adam optimizer with hyperparameters taken from args
        self.optim = Adam(self.model.parameters(),
                          lr=args.lr,
                          betas=args.betas,
                          weight_decay=args.weight_decay)
        # warmup-then-decay learning-rate schedule wrapped around the optimizer
        self.optim_schedule = ScheduledOptim(
            self.optim,
            init_lr=args.lr,
            n_warmup_steps=args.n_warmup_steps,
            steps_decay_scale=args.steps_decay_scale)

        self.log_freq = args.log_interval
        self.writer = writer

        print("Total Parameters:",
              sum([p.nelement() for p in self.model.parameters()]))

    def step(self, data):
        seq, coords, start_id, res_counts = data

        seq = seq.to(self.device)  # (N, L)
        coords = coords.to(self.device)  # (N, L, 3)
        start_id = start_id.to(self.device)  # (N, L)
        res_counts = res_counts.to(self.device)  # (N, 3)

        loss_r, loss_angle, loss_profile, loss_start_id, loss_res_counts = self.energy_fn.forward(
            seq, coords, start_id, res_counts)
        return loss_r, loss_angle, loss_profile, loss_start_id, loss_res_counts

    def train(self, epoch, data_loader, flag='Train'):
        for i, data in enumerate(tqdm(data_loader)):
            loss_r, loss_angle, loss_profile, loss_start_id, loss_res_counts = self.step(
                data)
            loss = loss_r + loss_angle + loss_profile + loss_start_id + loss_res_counts

            if flag == 'Train':
                self.optim_schedule.zero_grad()
                loss.backward()
                self.optim_schedule.step_and_update_lr()

            len_data_loader = len(data_loader)
            # log every batch during evaluation, every log_freq batches in training
            log_freq = self.log_freq if flag == 'Train' else 1
            if i % log_freq == 0:
                self.writer.add_scalar(f'{flag}/profile_loss',
                                       loss_profile.item(),
                                       epoch * len_data_loader + i)
                self.writer.add_scalar(f'{flag}/coords_radius_loss',
                                       loss_r.item(),
                                       epoch * len_data_loader + i)
                self.writer.add_scalar(f'{flag}/coords_angle_loss',
                                       loss_angle.item(),
                                       epoch * len_data_loader + i)
                self.writer.add_scalar(f'{flag}/start_id_loss',
                                       loss_start_id.item(),
                                       epoch * len_data_loader + i)
                self.writer.add_scalar(f'{flag}/res_counts_loss',
                                       loss_res_counts.item(),
                                       epoch * len_data_loader + i)
                self.writer.add_scalar(f'{flag}/total_loss', loss.item(),
                                       epoch * len_data_loader + i)

                print(f'{flag} epoch {epoch} Iter: {i} '
                      f'profile_loss: {loss_profile.item():.3f} '
                      f'coords_radius_loss: {loss_r.item():.3f} '
                      f'coords_angle_loss: {loss_angle.item():.3f} '
                      f'start_id_loss: {loss_start_id.item():.3f} '
                      f'res_counts_loss: {loss_res_counts.item():.3f} '
                      f'total_loss: {loss.item():.3f} ')

    def test(self, epoch, data_loader, flag='Test'):
        # Reuse the training loop for evaluation: switch to eval mode and
        # disable gradients, then restore the training state afterwards.
        self.model.eval()
        torch.set_grad_enabled(False)

        self.train(epoch, data_loader, flag=flag)

        self.model.train()
        torch.set_grad_enabled(True)
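
For context, a hypothetical driver for LocalGenTrainer might look like the sketch below. SummaryWriter comes from torch.utils.tensorboard, while model, args, train_loader, and test_loader (including args.log_dir and args.epochs) are placeholders for the project's own objects.

import torch
from torch.utils.tensorboard import SummaryWriter

# Hypothetical wiring: model, args, train_loader, and test_loader are
# stand-ins for the project's own objects.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
writer = SummaryWriter(log_dir=args.log_dir)
trainer = LocalGenTrainer(writer, model.to(device), device, args)

for epoch in range(args.epochs):
    trainer.train(epoch, train_loader)  # backprop + periodic logging
    trainer.test(epoch, test_loader)    # eval mode, gradients disabled
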
Example #3
import os
import pickle
import time

import torch.optim as optim

# Project-specific names (ScheduledOptim, logger, and the eval helper) are
# assumed to be importable from the surrounding codebase.


def train(args, model, train_iter, eval_iter=None):

    if args.use_cuda:
        model = model.cuda(args.device_no)
    model.train()

    optimizer = ScheduledOptim(
        optim.Adam(filter(lambda x: x.requires_grad, model.parameters())),
        args.learning_rate, args.warmup_steps)

    loss_list = []
    eval_loss_list = []
    batch_count = 0
    running_loss = 0
    while batch_count < args.training_steps:
        for inputs, targets in train_iter:
            if batch_count >= args.training_steps:
                break
            # `inputs` is a masked sequence; `targets` holds the original
            # token id at each masked position and -1 everywhere else, e.g.
            # input:  [101, 2342, 6537, 104,   104,  4423]
            # target: [-1,  -1,   -1,   10281, 8213, -1]

            logger.debug(f'inputs: {inputs}')
            logger.debug(f'targets: {targets}')
            batch_count += 1
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize; `targets` carries the masked-LM
            # labels described above (-1 marks positions to ignore)
            outputs = model(inputs, labels=targets)
            loss = outputs[0]
            loss.backward()
            optimizer.step()

            logger.debug(f'model outputs: {outputs}')

            running_loss += loss.item()

            # write the loss log; during warmup, log ten times as often
            if batch_count % args.log_interval == 0 or \
                    (batch_count < args.warmup_steps and
                     batch_count % int(args.log_interval / 10) == 0):
                if batch_count <= args.warmup_steps:
                    avg_loss = running_loss / args.log_interval * 10
                else:
                    avg_loss = running_loss / args.log_interval
                loss_list.append(avg_loss)
                logger.info('Batch:%6d, loss: %.6f  [%s]' %
                            (batch_count, avg_loss,
                             time.strftime("%D %H:%M:%S")))
                running_loss = 0

            # save model & curve
            if batch_count % args.checkpoint_interval == 0:
                if eval_iter is not None:
                    eval_loss = eval(args, model, eval_iter)
                    eval_loss_list.append(eval_loss)
                    if eval_loss <= min(
                            eval_loss_list) and args.save_best_checkpoint:
                        path = os.path.join(args.checkpoint_save_path, "model",
                                            f"{args.model_type}-best")
                        if not os.path.exists(path):
                            os.makedirs(path)
                        model.save_pretrained(path)
                        logger.info('Best model saved in %s' % path)
                if args.save_normal_checkpoint:
                    path = os.path.join(args.checkpoint_save_path, "tmp",
                                        f"{args.model_type}-{batch_count}")
                    if not os.path.exists(path):
                        os.makedirs(path)
                    model.save_pretrained(path)
                    logger.info('Model saved in %s' % path)
                    curve_info = {
                        "train_loss_list": loss_list,
                        "eval_loss_list": eval_loss_list
                    }
                    curve_path = os.path.join(
                        path, f'{args.model_type}-{batch_count}-loss.pkl')
                    with open(curve_path, 'wb+') as file:
                        pickle.dump(curve_info, file)
    return loss_list
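
The comments in Example #3 describe the (inputs, targets) convention: masked positions hold the mask-token id in inputs and the original token id in targets, with -1 everywhere else. A minimal sketch of such a masking step follows; mask_token_id and the 15% masking rate are placeholders, and it omits refinements such as BERT's 80/10/10 replacement rule and skipping special tokens.

import torch

def mask_tokens(input_ids, mask_token_id, mask_prob=0.15):
    """Return (inputs, targets): targets keeps the original id at masked
    positions and -1 elsewhere, matching the convention above."""
    inputs = input_ids.clone()
    targets = torch.full_like(input_ids, -1)
    mask = torch.rand(input_ids.shape, device=input_ids.device) < mask_prob
    targets[mask] = input_ids[mask]  # remember the original tokens
    inputs[mask] = mask_token_id     # replace them with the mask token
    return inputs, targets
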