Example #1
0
def training(args):
    """Train the source-to-target Transformer and checkpoint the best epoch.

    The function sets up a tqdm-compatible logger, loads the pre-processed
    index sequences from ``processed.pkl``, builds the train/valid data
    loaders, instantiates the model, optimizer, scheduler and ``GradScaler``
    (optionally resuming from a checkpoint) and then runs one training phase
    followed by one validation phase per epoch. A checkpoint is written every
    time the mean validation accuracy improves.

    Args:
        args: Parsed options object. Attributes read directly here are
            ``preprocess_path``, ``min_len``, ``src_max_len``, ``trg_max_len``,
            ``batch_size``, ``num_workers``, ``pad_id``, ``bos_id``,
            ``eos_id``, ``d_model``, ``d_embedding``, ``n_head``,
            ``dim_feedforward``, ``num_common_layer``, ``num_encoder_layer``,
            ``num_decoder_layer``, ``dropout``, ``embedding_dropout``,
            ``trg_emb_prj_weight_sharing``, ``emb_src_trg_weight_sharing``,
            ``parallel``, ``resume``, ``save_path``, ``num_epochs``,
            ``clip_grad_norm``, ``scheduler`` and ``print_freq``. It is also
            forwarded to ``optimizer_select`` and ``shceduler_select``.

    Returns:
        None. Results are reported through the logger / ``print`` and the
        checkpoint file ``checkpoint_{args.parallel}.pth.tar``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #===================================#
    #==============Logging==============#
    #===================================#

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    # Attach the handler only once so repeated calls do not duplicate lines.
    if not logger.handlers:
        handler = TqdmLoggingHandler()
        handler.setFormatter(
            logging.Formatter(" %(asctime)s - %(message)s",
                              "%Y-%m-%d %H:%M:%S"))
        logger.addHandler(handler)
    logger.propagate = False

    #===================================#
    #============Data Load==============#
    #===================================#

    # 1) Data open
    write_log(logger, "Load data...")
    # The collector is paused while unpickling; try/finally guarantees it is
    # re-enabled even if the file is missing or corrupt.
    # NOTE: pickle.load executes arbitrary code - only load trusted files.
    gc.disable()
    try:
        with open(os.path.join(args.preprocess_path, 'processed.pkl'),
                  'rb') as f:
            data_ = pickle.load(f)
    finally:
        gc.enable()
    train_src_indices = data_['train_src_indices']
    valid_src_indices = data_['valid_src_indices']
    train_trg_indices = data_['train_trg_indices']
    valid_trg_indices = data_['valid_trg_indices']
    src_vocab_num = len(data_['src_word2id'])
    trg_vocab_num = len(data_['trg_word2id'])
    del data_
    write_log(logger, "Finished loading data!")

    # 2) Dataloader setting
    dataset_dict = {
        'train':
        CustomDataset(train_src_indices,
                      train_trg_indices,
                      min_len=args.min_len,
                      src_max_len=args.src_max_len,
                      trg_max_len=args.trg_max_len),
        'valid':
        CustomDataset(valid_src_indices,
                      valid_trg_indices,
                      min_len=args.min_len,
                      src_max_len=args.src_max_len,
                      trg_max_len=args.trg_max_len),
    }
    dataloader_dict = {
        'train':
        DataLoader(dataset_dict['train'],
                   drop_last=True,
                   batch_size=args.batch_size,
                   shuffle=True,
                   pin_memory=True,
                   num_workers=args.num_workers),
        'valid':
        DataLoader(dataset_dict['valid'],
                   drop_last=False,
                   batch_size=args.batch_size,
                   shuffle=False,
                   pin_memory=True,
                   num_workers=args.num_workers)
    }
    write_log(
        logger,
        f"Total number of trainingsets  iterations - {len(dataset_dict['train'])}, {len(dataloader_dict['train'])}"
    )

    #===================================#
    #===========Train setting===========#
    #===================================#

    # 1) Model initiating
    write_log(logger, 'Instantiating model...')
    model = Transformer(
        src_vocab_num=src_vocab_num,
        trg_vocab_num=trg_vocab_num,
        pad_idx=args.pad_id,
        bos_idx=args.bos_id,
        eos_idx=args.eos_id,
        d_model=args.d_model,
        d_embedding=args.d_embedding,
        n_head=args.n_head,
        dim_feedforward=args.dim_feedforward,
        num_common_layer=args.num_common_layer,
        num_encoder_layer=args.num_encoder_layer,
        num_decoder_layer=args.num_decoder_layer,
        src_max_len=args.src_max_len,
        trg_max_len=args.trg_max_len,
        dropout=args.dropout,
        embedding_dropout=args.embedding_dropout,
        trg_emb_prj_weight_sharing=args.trg_emb_prj_weight_sharing,
        emb_src_trg_weight_sharing=args.emb_src_trg_weight_sharing,
        parallel=args.parallel)
    model.train()
    model = model.to(device)
    # The decoder input is trg[:, :-1], hence the mask size trg_max_len - 1.
    tgt_mask = model.generate_square_subsequent_mask(args.trg_max_len - 1,
                                                     device)

    # 2) Optimizer & Learning rate scheduler setting
    optimizer = optimizer_select(model, args)
    scheduler = shceduler_select(optimizer, dataloader_dict, args)
    scaler = GradScaler()

    # 3) Model resume
    # NOTE(review): the checkpoint is read from
    # <save_path>/checkpoint.pth.tar but written below to
    # checkpoint_<parallel>.pth.tar in the working directory - confirm the
    # intended location before relying on --resume.
    start_epoch = 0
    if args.resume:
        write_log(logger, 'Resume model...')
        checkpoint = torch.load(os.path.join(args.save_path,
                                             'checkpoint.pth.tar'),
                                map_location=device)
        # 'epoch' is stored 1-based (see the save below), i.e. it is the
        # number of completed epochs; the loop starts at start_epoch + 1.
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        scaler.load_state_dict(checkpoint['scaler'])
        del checkpoint

    #===================================#
    #=========Model Train Start=========#
    #===================================#

    # Kept as plain Python numbers so they can be formatted without .item().
    best_val_acc = 0
    best_epoch = None
    last_train_idx = len(dataloader_dict['train']) - 1

    write_log(logger, 'Traing start!')

    for epoch in range(start_epoch + 1, args.num_epochs + 1):
        start_time_e = time()
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
                freq = 0
            if phase == 'valid':
                write_log(logger, 'Validation start...')
                val_loss = 0
                val_acc = 0
                model.eval()
            for i, (src, trg) in enumerate(
                    tqdm(dataloader_dict[phase],
                         bar_format='{l_bar}{bar:30}{r_bar}{bar:-2b}')):

                # Input, output setting
                src = src.to(device, non_blocking=True)
                trg = trg.to(device, non_blocking=True)

                # Targets are the sequence shifted by one, with padding
                # positions removed and flattened to 1-D.
                trg_sequences_target = trg[:, 1:]
                non_pad = trg_sequences_target != args.pad_id
                trg_sequences_target = trg_sequences_target[
                    non_pad].contiguous().view(-1)

                # Train
                if phase == 'train':
                    optimizer.zero_grad(set_to_none=True)

                    # Loss calculate
                    with autocast():
                        predicted = model(src,
                                          trg[:, :-1],
                                          tgt_mask,
                                          non_pad_position=non_pad)
                        predicted = predicted.view(-1, predicted.size(-1))
                        loss = label_smoothing_loss(predicted,
                                                    trg_sequences_target,
                                                    args.pad_id)

                    # Unscale before clipping so the threshold applies to
                    # the true gradient norm.
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()

                    if args.scheduler in ['constant', 'warmup']:
                        scheduler.step()
                    if args.scheduler == 'reduce_train':
                        scheduler.step(loss)

                    # Log the first batch, every print_freq batches, and the
                    # final batch of the epoch.
                    if (i == 0 or freq == args.print_freq
                            or i == last_train_idx):
                        acc = (predicted.max(dim=1)[1] == trg_sequences_target
                               ).sum() / len(trg_sequences_target)
                        iter_log = "[Epoch:%03d][%03d/%03d] train_loss:%03.3f | train_acc:%03.2f%% | learning_rate:%1.6f | spend_time:%02.2fmin" % \
                            (epoch, i, len(dataloader_dict['train']),
                            loss.item(), acc*100, optimizer.param_groups[0]['lr'],
                            (time() - start_time_e) / 60)
                        write_log(logger, iter_log)
                        freq = 0
                    freq += 1

                # Validation
                if phase == 'valid':
                    with torch.no_grad():
                        predicted = model(src,
                                          trg[:, :-1],
                                          tgt_mask,
                                          non_pad_position=non_pad)
                        loss = F.cross_entropy(predicted, trg_sequences_target)
                    val_loss += loss.item()
                    val_acc += ((predicted.max(dim=1)[1]
                                 == trg_sequences_target).sum() /
                                len(trg_sequences_target)).item()
                    # NOTE(review): the 'lambda' scheduler is stepped once
                    # per validation batch, as in the original - confirm.
                    if args.scheduler == 'lambda':
                        scheduler.step()

            if phase == 'valid':
                val_loss /= len(dataloader_dict[phase])
                val_acc /= len(dataloader_dict[phase])
                write_log(logger, 'Validation Loss: %3.3f' % val_loss)
                write_log(logger,
                          'Validation Accuracy: %3.2f%%' % (val_acc * 100))
                # The plateau scheduler is fed the epoch-mean validation
                # loss once per epoch, not a running sum per batch.
                if args.scheduler == 'reduce_valid':
                    scheduler.step(val_loss)
                if val_acc > best_val_acc:
                    write_log(logger, 'Checkpoint saving...')
                    torch.save(
                        {
                            'epoch': epoch,
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'scaler': scaler.state_dict()
                        }, f'checkpoint_{args.parallel}.pth.tar')
                    best_val_acc = val_acc
                    best_epoch = epoch
                elif best_epoch is None:
                    write_log(
                        logger,
                        'No epoch has improved on the initial accuracy yet...')
                else:
                    else_log = f'Still {best_epoch} epoch accuracy({round(best_val_acc*100, 2)})% is better...'
                    write_log(logger, else_log)

    # 3) Print results
    print(f'Best Epoch: {best_epoch}')
    print(f'Best Accuracy: {round(best_val_acc, 2)}')
Example #2
0
def training(args):
    """Train the menu regression Transformer and checkpoint the best epoch.

    The model receives ``src_menu`` batches and is trained with ``MSELoss``
    against two targets, ``label_lunch`` and ``label_supper``; the loss is
    the sum of both. Each epoch runs a training phase and a validation
    phase, and a checkpoint is written to
    ``<args.save_path>/checkpoint_cap.pth.tar`` whenever the mean validation
    RMSE improves.

    Args:
        args: Parsed options object. Attributes read directly here are
            ``preprocessed_path``, ``batch_size``, ``num_workers``,
            ``model_type`` (``'sep'`` or ``'total'``), ``d_model``,
            ``d_embedding``, ``n_head``, ``dim_feedforward``,
            ``num_encoder_layer``, ``dropout``, ``resume``, ``model_path``,
            ``num_epochs``, ``clip_grad_norm``, ``scheduler``,
            ``print_freq`` and ``save_path``. It is also forwarded to
            ``optimizer_select`` and ``shceduler_select``.

    Raises:
        ValueError: If ``args.model_type`` is neither ``'sep'`` nor
            ``'total'``.

    Returns:
        None. Progress and results are reported through the logger.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Column of the model output used for the supper prediction.
    # NOTE(review): 'sep' reads column 0 for both lunch and supper, exactly
    # as the original code did - confirm this is intended.
    if args.model_type == 'sep':
        supper_col = 0
    elif args.model_type == 'total':
        supper_col = 1
    else:
        raise ValueError(
            f"Unknown model_type {args.model_type!r}; expected 'sep' or 'total'"
        )

    #===================================#
    #==============Logging==============#
    #===================================#

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    # Attach the handler only once so repeated calls do not duplicate lines.
    if not logger.handlers:
        handler = TqdmLoggingHandler()
        handler.setFormatter(
            logging.Formatter(" %(asctime)s - %(message)s",
                              "%Y-%m-%d %H:%M:%S"))
        logger.addHandler(handler)
    logger.propagate = False

    #===================================#
    #============Data Load==============#
    #===================================#

    # 1) Dataloader setting
    write_log(logger, "Load data...")
    # The collector is paused while the datasets are built; try/finally
    # guarantees it is re-enabled if construction fails.
    gc.disable()
    try:
        dataset_dict = {
            'train': CustomDataset(data_path=args.preprocessed_path,
                                   phase='train'),
            'valid': CustomDataset(data_path=args.preprocessed_path,
                                   phase='valid'),
            'test': CustomDataset(data_path=args.preprocessed_path,
                                  phase='test')
        }
        unique_menu_count = dataset_dict['train'].unique_count()
        dataloader_dict = {
            'train':
            DataLoader(dataset_dict['train'],
                       drop_last=True,
                       batch_size=args.batch_size,
                       shuffle=True,
                       pin_memory=True,
                       num_workers=args.num_workers,
                       collate_fn=PadCollate()),
            'valid':
            DataLoader(dataset_dict['valid'],
                       drop_last=False,
                       batch_size=args.batch_size,
                       shuffle=False,
                       pin_memory=True,
                       num_workers=args.num_workers,
                       collate_fn=PadCollate()),
            'test':
            DataLoader(dataset_dict['test'],
                       drop_last=False,
                       batch_size=args.batch_size,
                       shuffle=False,
                       pin_memory=True,
                       num_workers=args.num_workers,
                       collate_fn=PadCollate())
        }
    finally:
        gc.enable()
    write_log(
        logger,
        f"Total number of trainingsets  iterations - {len(dataset_dict['train'])}, {len(dataloader_dict['train'])}"
    )

    #===================================#
    #===========Model setting===========#
    #===================================#

    # 1) Model initiating
    write_log(logger, "Instantiating models...")
    model = Transformer(model_type=args.model_type,
                        input_size=unique_menu_count,
                        d_model=args.d_model,
                        d_embedding=args.d_embedding,
                        n_head=args.n_head,
                        dim_feedforward=args.dim_feedforward,
                        num_encoder_layer=args.num_encoder_layer,
                        dropout=args.dropout)
    model = model.train()
    model = model.to(device)

    # 2) Optimizer setting
    optimizer = optimizer_select(model, args)
    scheduler = shceduler_select(optimizer, dataloader_dict, args)
    criterion = nn.MSELoss()
    scaler = GradScaler(enabled=True)

    # NOTE(review): this wraps the model with `amp.initialize` while the
    # loop below also uses the GradScaler/autocast pair - confirm that both
    # mixed-precision mechanisms are meant to be active at once.
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # 3) Model resume
    # NOTE(review): the checkpoint is read from
    # <model_path>/checkpoint.pth.tar but written below to
    # <save_path>/checkpoint_cap.pth.tar - confirm the intended file.
    start_epoch = 0
    if args.resume:
        checkpoint = torch.load(os.path.join(args.model_path,
                                             'checkpoint.pth.tar'),
                                map_location='cpu')
        # 'epoch' is stored 0-based, so training continues at the next one.
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        # The scaler state is saved with every checkpoint; restore it when
        # present (older checkpoints may lack the key).
        if 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])
        model = model.train()
        model = model.to(device)
        del checkpoint

    #===================================#
    #=========Model Train Start=========#
    #===================================#

    best_val_rmse = 9999999
    best_epoch = None
    last_train_idx = len(dataloader_dict['train']) - 1

    write_log(logger, 'Train start!')

    for epoch in range(start_epoch, args.num_epochs):
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
                train_start_time = time.time()
                freq = 0
            elif phase == 'valid':
                model.eval()
                val_loss = 0
                val_rmse = 0

            for i, (src_menu, label_lunch,
                    label_supper) in enumerate(dataloader_dict[phase]):

                # Input, output setting
                src_menu = src_menu.to(device, non_blocking=True)
                label_lunch = label_lunch.float().to(device, non_blocking=True)
                label_supper = label_supper.float().to(device,
                                                       non_blocking=True)

                # Gradients are only needed (and only cleared) when training.
                if phase == 'train':
                    optimizer.zero_grad()

                # Model
                with torch.set_grad_enabled(phase == 'train'):
                    with autocast(enabled=True):
                        logit = model(src_menu)
                        logit_lunch = logit[:, 0]
                        logit_supper = logit[:, supper_col]

                    # Loss calculate (outside autocast, as in the original)
                    loss_lunch = criterion(logit_lunch, label_lunch)
                    loss_supper = criterion(logit_supper, label_supper)
                    loss = loss_lunch + loss_supper

                # Back-propagation
                if phase == 'train':
                    # Unscale before clipping so the threshold applies to
                    # the true gradient norm.
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()

                    # Scheduler setting
                    if args.scheduler in ['constant', 'warmup']:
                        scheduler.step()
                    if args.scheduler == 'reduce_train':
                        scheduler.step(loss)

                # Print loss value
                rmse_loss = torch.sqrt(loss)
                if phase == 'train':
                    # Log the first batch, every print_freq batches, and the
                    # final batch of the epoch.
                    if (i == 0 or freq == args.print_freq
                            or i == last_train_idx):
                        batch_log = "[Epoch:%d][%d/%d] train_MSE_loss:%2.3f  | train_RMSE_loss:%2.3f | learning_rate:%3.6f | spend_time:%3.2fmin" \
                                % (epoch+1, i, len(dataloader_dict['train']),
                                loss.item(), rmse_loss.item(), optimizer.param_groups[0]['lr'],
                                (time.time() - train_start_time) / 60)
                        write_log(logger, batch_log)
                        freq = 0
                    freq += 1
                elif phase == 'valid':
                    val_loss += loss.item()
                    val_rmse += rmse_loss.item()

        # Epoch summary: both phases have run, so the validation sums are
        # complete at this point.
        val_loss /= len(dataloader_dict['valid'])
        val_rmse /= len(dataloader_dict['valid'])
        write_log(logger, 'Validation Loss: %3.3f' % val_loss)
        write_log(logger, 'Validation RMSE: %3.3f' % val_rmse)

        if val_rmse < best_val_rmse:
            write_log(logger, 'Checkpoint saving...')
            # makedirs handles nested paths and an already-existing folder.
            os.makedirs(args.save_path, exist_ok=True)
            torch.save(
                {
                    'epoch': epoch,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'scaler': scaler.state_dict()
                }, os.path.join(args.save_path, 'checkpoint_cap.pth.tar'))
            best_val_rmse = val_rmse
            best_epoch = epoch
        elif best_epoch is None:
            # Reached e.g. when the validation RMSE is NaN in the first epoch.
            write_log(logger,
                      'No epoch has improved on the initial RMSE yet...')
        else:
            else_log = f'Still {best_epoch} epoch RMSE({round(best_val_rmse, 3)}) is better...'
            write_log(logger, else_log)

    # 4) Final report
    if best_epoch is None:
        write_log(logger, 'No epoch improved; no checkpoint was saved.')
    else:
        write_log(logger, f'Best Epoch: {best_epoch+1}')
        write_log(logger, f'Best RMSE: {round(best_val_rmse, 3)}')
Example #3
0
    # Apply Xavier-uniform initialisation to selected parameters that have
    # more than one dimension.
    # NOTE(review): as written, the inner condition also requires
    # config.pretrain_emb to be truthy, so no parameter is initialised here
    # when it is falsy. If the intent was "skip only the embedding table
    # when pretrained embeddings are used", the test would be
    # `not (n == "embedding.lut.weight" and config.pretrain_emb)` - confirm.
    for n, p in model.named_parameters():
        if p.dim() > 1 and (n != "embedding.lut.weight"
                            and config.pretrain_emb):
            xavier_uniform_(p)
# Report which model is configured and its parameter count.
print("MODEL USED", config.model)
print("TRAINABLE PARAMETERS", count_parameters(model))

# Run the evaluation branch below every `check_iter` training iterations.
check_iter = 2000
# NOTE: the matching except/finally clause of this try lies outside the
# visible part of the file.
try:
    if (config.USE_CUDA):
        model.cuda()
    model = model.train()
    best_ppl = 1000
    patient = 0
    writer = SummaryWriter(log_dir=config.save_path)
    # Snapshot of the initial weights, copied so later updates do not alias it.
    weights_best = deepcopy(model.state_dict())
    # Wrap the training loader so next() can be called on every iteration.
    data_iter = make_infinite(data_loader_tra)
    for n_iter in tqdm(range(1000000)):
        # One optimisation step; the four returned metrics are logged below.
        loss, ppl, bce, acc = model.train_one_batch(next(data_iter), n_iter)
        writer.add_scalars('loss', {'loss_train': loss}, n_iter)
        writer.add_scalars('ppl', {'ppl_train': ppl}, n_iter)
        writer.add_scalars('bce', {'bce_train': bce}, n_iter)
        writer.add_scalars('accuracy', {'acc_train': acc}, n_iter)
        # Also log the optimizer's current rate when config.noam is set.
        if (config.noam):
            writer.add_scalars('lr', {'learning_rata': model.optimizer._rate},
                               n_iter)

        # Periodic check: switch to eval mode and record the iteration.
        if ((n_iter + 1) % check_iter == 0):
            model = model.eval()
            model.epoch = n_iter
            model.__id__logger = 0
Example #4
0
            _, _, _ = model.train_one_batch(d)
    generate(model, val_iter, persona)


# Script entry: per-persona adaptation over the 'test' split.
p = Personas()
# Build model, optimizer, and set states
print("Test model", config.model)
model = Transformer(p.vocab, model_file_path=config.save_path, is_eval=False)
# get persona map
# NOTE: pickle.load executes arbitrary code - only load trusted files.
filename = 'data/ConvAI2/test_persona_map'
with open(filename, 'rb') as f:
    persona_map = pickle.load(f)

#generate
iterations = 11
# Deep copy of the loaded weights, taken before the per-persona loop.
weights_original = deepcopy(model.state_dict())
tasks = p.get_personas('test')
for per in tqdm(tasks):
    num_of_dialog = p.get_num_of_dialog(persona=per, split='test')
    # One train/val loader pair per dialogue index, passed as `fold`.
    for val_dial_index in range(num_of_dialog):
        train_iter, val_iter = p.get_data_loader(persona=per,
                                                 batch_size=config.batch_size,
                                                 split='test',
                                                 fold=val_dial_index)
        # Flatten the persona entries and drop duplicates (order not kept).
        persona = []
        for ppp in persona_map[per]:
            persona += ppp
        persona = list(set(persona))
        do_learning(model,
                    train_iter,
                    val_iter,
Example #5
0
        # Resume: the epoch number is the trailing "_<n>" of the last weight
        # file name (see the 'transformer_%d' save pattern below).
        last_epoch = int(weights[-1].split('_')[-1])
        # Normalise Windows path separators.
        weight_path = weights[-1].replace('\\', '/')
        print('weight info of last epoch', weight_path)
        model.load_state_dict(torch.load(weight_path))
        total_epoch = last_epoch + plus_epoch
    else:
        # Fresh run: no previous weights to continue from.
        last_epoch = 0
        total_epoch = plus_epoch

    model.train()
    # Train for `plus_epoch` additional epochs.
    for epoch in range(plus_epoch):
        epoch_loss = 0
        for iteration, data in enumerate(data_loader):
            encoder_inputs, decoder_inputs, targets = data
            optimizer.zero_grad()
            logits, _ = model(encoder_inputs, decoder_inputs)
            # Flatten logits to (-1, trg_vocab_size) and targets to 1-D.
            logits = logits.contiguous().view(-1, trg_vocab_size)
            targets = targets.contiguous().view(-1)
            loss = criterion(logits, targets)
            loss.backward()
            # Clip the gradient norm at 0.5 before the optimizer step.
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            # NOTE(review): this adds `loss` itself rather than a detached
            # number; if it is a tensor attached to the graph, memory may
            # grow over the epoch - consider loss.item(). Confirm.
            epoch_loss += loss
            # Print the running mean loss every 500 iterations.
            if (iteration + 1) % 500 == 0:
                print('Epoch: %3d\t' % (last_epoch + epoch + 1),
                      'Iteration: %3d \t' % (iteration + 1),
                      'Cost: {:.5f}'.format(epoch_loss / (iteration + 1)))
        # The scheduler is stepped once per epoch with the summed loss.
        scheduler.step(epoch_loss)
    # Save the final weights, tagged with the cumulative epoch count.
    model_path = './weight/transformer_%d' % total_epoch
    torch.save(model.state_dict(), model_path)
Example #6
0
# Meta-training setup: personas of the 'train' split are the tasks.
meta_batch_size = config.meta_batch_size
tasks = p.get_personas('train')
#tasks_loader = {t: p.get_data_loader(persona=t,batch_size=config.batch_size, split='train') for t in tasks}
tasks_iter = make_infinite_list(tasks)

# meta early stop
# Patience is doubled when config.fix_dialnum_train is set.
patience = 50
if config.fix_dialnum_train:
    patience = 100
best_loss = 10000000
stop_count = 0
# Main loop
for meta_iteration in range(config.epochs):
    ## save original weights to make the update
    weights_original = deepcopy(meta_net.state_dict())
    train_loss_before = []
    train_loss_meta = []
    #loss accumulate from a batch of tasks
    batch_loss = 0
    # Draw `meta_batch_size` tasks per meta-iteration.
    for _ in range(meta_batch_size):
        # Get task
        # Both branches take the next persona from tasks_iter; they differ
        # only in which loader factory is called.
        if config.fix_dialnum_train:
            train_iter, val_iter = p.get_balanced_loader(
                persona=tasks_iter.__next__(),
                batch_size=config.batch_size,
                split='train')
        else:
            train_iter, val_iter = p.get_data_loader(
                persona=tasks_iter.__next__(),
                batch_size=config.batch_size,
Example #7
0
class Trainer:
    """Train, validate and test a ``Transformer`` translation model.

    Batches are expected to expose the source sequence as ``batch.kor`` and
    the target sequence as ``batch.eng``. The model is called as
    ``model(source, target[:, :-1])`` and the first element of its return
    value is scored against ``target[:, 1:]``.
    """

    def __init__(self,
                 params,
                 mode,
                 train_iter=None,
                 valid_iter=None,
                 test_iter=None):
        """Build the model, optimizer and loss for the given mode.

        Args:
            params: Configuration object. Attributes read by this class are
                ``device``, ``hidden_dim``, ``warm_steps``, ``pad_idx``,
                ``num_epoch``, ``clip`` and ``save_model``; it is also passed
                to ``Transformer``.
            mode: ``'train'`` stores ``train_iter`` and ``valid_iter``; any
                other value stores ``test_iter`` only.
            train_iter: Iterable of training batches (train mode).
            valid_iter: Iterable of validation batches (train mode).
            test_iter: Iterable of test batches (non-train mode).
        """
        self.params = params

        # Train mode
        if mode == 'train':
            self.train_iter = train_iter
            self.valid_iter = valid_iter

        # Test mode
        else:
            self.test_iter = test_iter

        self.model = Transformer(self.params)
        self.model.to(self.params.device)

        # Scheduling Optimzer
        self.optimizer = ScheduledAdam(optim.Adam(self.model.parameters(),
                                                  betas=(0.9, 0.98),
                                                  eps=1e-9),
                                       hidden_dim=params.hidden_dim,
                                       warm_steps=params.warm_steps)

        # Padding positions do not contribute to the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.params.pad_idx)
        self.criterion.to(self.params.device)

    def _mean_loss(self, iterator):
        """Return the mean per-batch loss over ``iterator`` in eval mode.

        Shared by ``evaluate`` and ``inference``; gradients are disabled.
        """
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for batch in iterator:
                source = batch.kor
                target = batch.eng

                # Decoder input drops the last token; the ground truth drops
                # the first one.
                output = self.model(source, target[:, :-1])[0]

                output = output.contiguous().view(-1, output.shape[-1])
                target = target[:, 1:].contiguous().view(-1)

                total_loss += self.criterion(output, target).item()

        return total_loss / len(iterator)

    def train(self):
        """Run the training loop and save the best-validation weights.

        For each of ``params.num_epoch`` epochs the model is trained on
        ``train_iter``, evaluated on ``valid_iter``, and its ``state_dict``
        is written to ``params.save_model`` whenever the validation loss
        improves. Per-epoch timing and losses are printed.
        """
        print(self.model)
        print(
            f'The model has {self.model.count_params():,} trainable parameters'
        )
        best_valid_loss = float('inf')

        for epoch in range(self.params.num_epoch):
            self.model.train()
            epoch_loss = 0
            start_time = time.time()

            for batch in self.train_iter:
                # For each batch, first zero the gradients
                self.optimizer.zero_grad()
                source = batch.kor
                target = batch.eng

                # target sentence consists of <sos> and following tokens (except the <eos> token)
                output = self.model(source, target[:, :-1])[0]

                # ground truth sentence consists of tokens and <eos> token (except the <sos> token)
                output = output.contiguous().view(-1, output.shape[-1])
                target = target[:, 1:].contiguous().view(-1)
                # output = [(batch size * target length - 1), output dim]
                # target = [(batch size * target length - 1)]
                loss = self.criterion(output, target)
                loss.backward()

                # clip the gradients to prevent the model from exploding gradient
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.params.clip)

                self.optimizer.step()

                # 'item' method is used to extract a scalar from a tensor which only contains a single value.
                epoch_loss += loss.item()

            train_loss = epoch_loss / len(self.train_iter)
            # evaluate() leaves the model in eval mode; train() is re-enabled
            # at the top of the next epoch.
            valid_loss = self.evaluate()

            end_time = time.time()
            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(self.model.state_dict(), self.params.save_model)

            print(
                f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s'
            )
            print(
                f'\tTrain Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}'
            )

    def evaluate(self):
        """Return the mean per-batch loss on ``valid_iter``."""
        return self._mean_loss(self.valid_iter)

    def inference(self):
        """Load the saved weights and print the mean loss on ``test_iter``."""
        # map_location lets a checkpoint saved on one device be loaded on
        # whichever device this run is configured for.
        self.model.load_state_dict(
            torch.load(self.params.save_model,
                       map_location=self.params.device))
        test_loss = self._mean_loss(self.test_iter)
        print(f'Test Loss: {test_loss:.3f}')