Example #1
File: train.py Project: Jwoo5/temp
    train_dataset = MITBIHDataset(train_path)
    test_dataset = MITBIHDataset(test_path)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch,
                                                sampler=None, shuffle=True,
                                                num_workers=args.num_workers)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch,
                                               sampler=None, shuffle=True,
                                               num_workers=args.num_workers)

    model = Transformer(config.n_layer, config.d_model, config.n_head, config.d_head, config.d_ff, config.n_classes, config.dropout)

    criterion_cls = torch.nn.CrossEntropyLoss()

    t_total = len(train_loader) * args.epoch
    
    lr = args.lr
    params = []
    # Double the learning rate and drop weight decay for bias terms;
    # every other trainable parameter keeps the base rate with weight decay.
    for key, value in dict(model.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                params += [{'params': [value], 'lr': lr * 2, 'weight_decay': 0}]
            else:
                params += [{'params': [value], 'lr': lr, 'weight_decay': 0.0005}]
    
    lr = lr * 0.1
    optimizer = torch.optim.Adam(params)
    model.to(config.device)
    wandb.watch(model)    

    inputs = torch.FloatTensor(1)
    label = torch.LongTensor(1)
    inputs = inputs.to(config.device)
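
The snippet above stops before the training loop. A minimal continuation sketch, assuming the dataset yields (input, label) pairs and that the Transformer's forward maps a batch of inputs to class logits (the helper name and that signature are assumptions, not taken from the original file):

def train_one_epoch(model, loader, criterion, optimizer, device):
    # One pass over the loader; returns the mean training loss.
    model.train()
    total_loss = 0.0
    for inputs, label in loader:
        inputs, label = inputs.to(device), label.to(device)
        optimizer.zero_grad()
        logits = model(inputs)  # assumed signature: raw signals -> class logits
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(len(loader), 1)

It would be called once per epoch, e.g. train_one_epoch(model, train_loader, criterion_cls, optimizer, config.device).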
Example #2
                            is_eval=True)
    elif (config.model == "experts"):
        model = Transformer_experts(vocab,
                                    decoder_number=program_number,
                                    model_file_path=config.save_path,
                                    is_eval=True)
    if (config.USE_CUDA):
        model.cuda()
    model = model.eval()
    loss_test, ppl_test, bce_test, acc_test, bleu_score_g, bleu_score_b = evaluate(
        model, data_loader_tst, ty="test", max_dec_step=50)
    exit(0)

if (config.model == "trs"):
    model = Transformer(vocab, decoder_number=program_number)
    for n, p in model.named_parameters():
        if p.dim() > 1 and (n != "embedding.lut.weight"
                            and config.pretrain_emb):
            xavier_uniform_(p)
elif (config.model == "experts"):
    model = Transformer_experts(vocab, decoder_number=program_number)
    for n, p in model.named_parameters():
        if p.dim() > 1 and (n != "embedding.lut.weight"
                            and config.pretrain_emb):
            xavier_uniform_(p)
print("MODEL USED", config.model)
print("TRAINABLE PARAMETERS", count_parameters(model))

check_iter = 2000
try:
    if (config.USE_CUDA):
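
Both model branches above initialize parameters the same way. A standalone sketch of that initialization pattern, mirroring the condition in the snippet (the helper name is an assumption; the parameter name "embedding.lut.weight" comes from the snippet):

from torch.nn.init import xavier_uniform_

def init_transformer_weights(model, pretrain_emb):
    # Xavier-init weight matrices (dim > 1) only when a pretrained embedding
    # is used, and never touch the embedding lookup table itself.
    for name, p in model.named_parameters():
        if p.dim() > 1 and name != "embedding.lut.weight" and pretrain_emb:
            xavier_uniform_(p)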
Example #3
def main(args):
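    # Derive rank and world size from MPI, then initialize an NCCL process group for distributed training.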
    comm = MPI.COMM_WORLD
    world_size = comm.Get_size()
    rank = comm.Get_rank()
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(args.master_port)
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
    device = torch.device("cuda")

    logger = None
    tb_logger = None
    if rank == 0:
        if not os.path.exists(args.save_path):
            os.mkdir(args.save_path)
        if not os.path.exists(args.tensorboard_log_dir):
            os.mkdir(args.tensorboard_log_dir)
        tb_logger = SummaryWriter(
            f"{args.tensorboard_log_dir}/{args.model_name}")

        logger = logging.getLogger(__name__)
        logger.setLevel(logging.DEBUG)
        handler = TqdmLoggingHandler()
        handler.setFormatter(logging.Formatter(" %(asctime)s - %(message)s"))
        logger.addHandler(handler)
        logger.propagate = False

    write_log(logger, "Load data")

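    # gc is disabled while the large preprocessed pickles are read, then re-enabled afterwards.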
    def load_data(args):
        gc.disable()
        with open(f"{args.preprocessed_data_path}/hanja_korean_word2id.pkl",
                  "rb") as f:
            data = pickle.load(f)
            hanja_word2id = data['hanja_word2id']
            korean_word2id = data['korean_word2id']

        with open(f"{args.preprocessed_data_path}/preprocessed_train.pkl",
                  "rb") as f:
            data = pickle.load(f)
            train_hanja_indices = data['hanja_indices']
            train_korean_indices = data['korean_indices']
            train_additional_hanja_indices = data['additional_hanja_indices']

        with open(f"{args.preprocessed_data_path}/preprocessed_valid.pkl",
                  "rb") as f:
            data = pickle.load(f)
            valid_hanja_indices = data['hanja_indices']
            valid_korean_indices = data['korean_indices']
            valid_additional_hanja_indices = data['additional_hanja_indices']

        gc.enable()
        write_log(logger, "Finished loading data!")
        return (hanja_word2id, korean_word2id, train_hanja_indices,
                train_korean_indices, train_additional_hanja_indices,
                valid_hanja_indices, valid_korean_indices,
                valid_additional_hanja_indices)

    # load data
    (hanja_word2id, korean_word2id, train_hanja_indices, train_korean_indices,
     train_additional_hanja_indices, valid_hanja_indices, valid_korean_indices,
     valid_additional_hanja_indices) = load_data(args)
    hanja_vocab_num = len(hanja_word2id)
    korean_vocab_num = len(korean_word2id)

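    # DistributedSampler gives each rank a disjoint shard of every dataset.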
    hk_dataset = HanjaKoreanDataset(train_hanja_indices,
                                    train_korean_indices,
                                    min_len=args.min_len,
                                    src_max_len=args.src_max_len,
                                    trg_max_len=args.trg_max_len)
    hk_sampler = DistributedSampler(hk_dataset,
                                    num_replicas=world_size,
                                    rank=rank)
    hk_loader = DataLoader(hk_dataset,
                           drop_last=True,
                           batch_size=args.hk_batch_size,
                           sampler=hk_sampler,
                           num_workers=args.num_workers,
                           prefetch_factor=4,
                           pin_memory=True)
    write_log(logger, f"hanja-korean: {len(hk_dataset)}, {len(hk_loader)}")

    h_dataset = HanjaDataset(train_hanja_indices,
                             train_additional_hanja_indices,
                             hanja_word2id,
                             min_len=args.min_len,
                             src_max_len=args.src_max_len)
    h_sampler = DistributedSampler(h_dataset,
                                   num_replicas=world_size,
                                   rank=rank)
    h_loader = DataLoader(h_dataset,
                          drop_last=True,
                          batch_size=args.h_batch_size,
                          sampler=h_sampler,
                          num_workers=args.num_workers,
                          prefetch_factor=4,
                          pin_memory=True)
    write_log(logger, f"hanja: {len(h_dataset)}, {len(h_loader)}")

    hk_valid_dataset = HanjaKoreanDataset(valid_hanja_indices,
                                          valid_korean_indices,
                                          min_len=args.min_len,
                                          src_max_len=args.src_max_len,
                                          trg_max_len=args.trg_max_len)
    hk_valid_sampler = DistributedSampler(hk_valid_dataset,
                                          num_replicas=world_size,
                                          rank=rank)
    hk_valid_loader = DataLoader(hk_valid_dataset,
                                 drop_last=True,
                                 batch_size=args.hk_batch_size,
                                 sampler=hk_valid_sampler)
    write_log(
        logger,
        f"hanja-korean-valid: {len(hk_valid_dataset)}, {len(hk_valid_loader)}")

    h_valid_dataset = HanjaDataset(valid_hanja_indices,
                                   valid_additional_hanja_indices,
                                   hanja_word2id,
                                   min_len=args.min_len,
                                   src_max_len=args.src_max_len)
    h_valid_sampler = DistributedSampler(h_valid_dataset,
                                         num_replicas=world_size,
                                         rank=rank)
    h_valid_loader = DataLoader(h_valid_dataset,
                                drop_last=True,
                                batch_size=args.h_batch_size,
                                sampler=h_valid_sampler)
    write_log(logger, f"hanja: {len(h_valid_dataset)}, {len(h_valid_loader)}")

    del (train_hanja_indices, train_korean_indices,
         train_additional_hanja_indices, valid_hanja_indices,
         valid_korean_indices, valid_additional_hanja_indices)

    write_log(logger, "Build model")
    model = Transformer(hanja_vocab_num,
                        korean_vocab_num,
                        pad_idx=args.pad_idx,
                        bos_idx=args.bos_idx,
                        eos_idx=args.eos_idx,
                        src_max_len=args.src_max_len,
                        trg_max_len=args.trg_max_len,
                        d_model=args.d_model,
                        d_embedding=args.d_embedding,
                        n_head=args.n_head,
                        dropout=args.dropout,
                        dim_feedforward=args.dim_feedforward,
                        num_encoder_layer=args.num_encoder_layer,
                        num_decoder_layer=args.num_decoder_layer,
                        num_mask_layer=args.num_mask_layer).to(device)
    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[device],
                                                find_unused_parameters=True)
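    # Broadcast rank 0's initial weights so every process starts from identical parameters.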
    for param in model.parameters():
        dist.broadcast(param.data, 0)

    dist.barrier()
    write_log(
        logger,
        f"Total Parameters: {sum([p.nelement() for p in model.parameters()])}")

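    # Exclude biases and LayerNorm parameters from weight decay.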
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = Ralamb(params=optimizer_grouped_parameters, lr=args.lr)

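    # The scheduler steps once per gradient-accumulation cycle, so total_iters counts optimizer steps rather than batches.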
    total_iters = round(
        len(hk_loader) / args.num_grad_accumulate * args.epochs)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, round(total_iters * args.warmup_ratio), total_iters)
    scaler = GradScaler()

    start_epoch = 0
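    # Optionally resume model/optimizer/scheduler/scaler state from the last saved checkpoint.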
    if args.resume:

        def load_states():
            checkpoint = torch.load(
                f'{args.save_path}/{args.model_name}_ckpt.pt',
                map_location='cpu')
            start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            scaler.load_state_dict(checkpoint['scaler'])
            return start_epoch

        start_epoch = load_states()

    write_log(logger, f"Training start - Total iter: {total_iters}\n")
    iter_num = round(len(hk_loader) / args.num_grad_accumulate)
    global_step = start_epoch * iter_num
    hk_iter = iter(hk_loader)
    h_iter = iter(h_loader)
    model.train()
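    # Causal (subsequent-position) mask, built once and reused for every decoder forward pass.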
    tgt_mask = Transformer.generate_square_subsequent_mask(
        args.trg_max_len - 1, device)

    # validation
    validate(model, tgt_mask, h_valid_loader, hk_valid_loader, rank, logger,
             tb_logger, 0, device)

    for epoch in range(start_epoch + 1, args.epochs + 1):
        while True:
            start = time.time()
            finish_epoch = False
            trans_top5, trans_loss, mask_top5, mask_loss = 0.0, 0.0, 0.0, 0.0

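            # Masked-reconstruction objective on the hanja-only data, accumulated over num_grad_accumulate micro-batches.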
            if args.train_reconstruct:
                optimizer.zero_grad(set_to_none=True)
                for _ in range(args.num_grad_accumulate):
                    try:
                        src_sequences, trg_sequences = next(h_iter)
                    except StopIteration:
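                        # Shard exhausted: re-shuffle for the new epoch and restart the iterator.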
                        h_sampler.set_epoch(epoch)
                        h_iter = iter(h_loader)
                        src_sequences, trg_sequences = next(h_iter)

                    trg_sequences = trg_sequences.to(device)
                    src_sequences = src_sequences.to(device)
                    non_pad = trg_sequences != args.pad_idx
                    trg_sequences = trg_sequences[non_pad].contiguous().view(
                        -1)

                    with autocast():
                        predicted = model.module.reconstruct_predict(
                            src_sequences, masked_position=non_pad)
                        predicted = predicted.view(-1, predicted.size(-1))
                        loss = label_smoothing_loss(
                            predicted,
                            trg_sequences) / args.num_grad_accumulate

                    scaler.scale(loss).backward()

                    if global_step % args.print_freq == 0:
                        mask_top5 += accuracy(predicted, trg_sequences,
                                              5) / args.num_grad_accumulate
                        mask_loss += loss.detach().item()

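                # Average the accumulated gradients across ranks before the optimizer step.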
                for param in model.parameters():
                    if param.grad is not None:
                        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                        param.grad.data = param.grad.data / world_size

                scaler.step(optimizer)
                scaler.update()

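            # Translation objective (hanja -> korean) with teacher forcing, using the same accumulation scheme.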
            if args.train_translate:
                optimizer.zero_grad(set_to_none=True)
                for _ in range(args.num_grad_accumulate):
                    try:
                        src_sequences, trg_sequences = next(hk_iter)
                    except StopIteration:
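                        # hanja-korean loader exhausted: re-shuffle and flag the end of the epoch.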
                        hk_sampler.set_epoch(epoch)
                        hk_iter = iter(hk_loader)
                        src_sequences, trg_sequences = next(hk_iter)
                        finish_epoch = True

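                    # Teacher forcing: the decoder consumes trg[:, :-1] and is scored against trg[:, 1:].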
                    trg_sequences = trg_sequences.to(device)
                    trg_sequences_target = trg_sequences[:, 1:]
                    src_sequences = src_sequences.to(device)
                    non_pad = trg_sequences_target != args.pad_idx
                    trg_sequences_target = trg_sequences_target[
                        non_pad].contiguous().view(-1)

                    with autocast():
                        predicted = model(src_sequences,
                                          trg_sequences[:, :-1],
                                          tgt_mask,
                                          non_pad_position=non_pad)
                        predicted = predicted.view(-1, predicted.size(-1))
                        loss = label_smoothing_loss(
                            predicted,
                            trg_sequences_target) / args.num_grad_accumulate

                    scaler.scale(loss).backward()

                    if global_step % args.print_freq == 0:
                        trans_top5 += accuracy(predicted, trg_sequences_target,
                                               5) / args.num_grad_accumulate
                        trans_loss += loss.detach().item()

                for param in model.parameters():
                    if param.grad is not None:
                        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                        param.grad.data = param.grad.data / world_size

                scaler.step(optimizer)
                scaler.update()

            scheduler.step()

            # Print status: all-reduce the running metrics across ranks, then log from rank 0 only.
            if global_step % args.print_freq == 0:
                if args.train_reconstruct:
                    mask_top5 = torch.cuda.FloatTensor([mask_top5])
                    mask_loss = torch.cuda.FloatTensor([mask_loss])
                    dist.all_reduce(mask_top5, op=dist.ReduceOp.SUM)
                    dist.all_reduce(mask_loss, op=dist.ReduceOp.SUM)
                    mask_top5 = (mask_top5 / world_size).item()
                    mask_loss = (mask_loss / world_size).item()

                if args.train_translate:
                    trans_top5 = torch.cuda.FloatTensor([trans_top5])
                    trans_loss = torch.cuda.FloatTensor([trans_loss])
                    dist.all_reduce(trans_top5, op=dist.ReduceOp.SUM)
                    dist.all_reduce(trans_loss, op=dist.ReduceOp.SUM)
                    trans_top5 = (trans_top5 / world_size).item()
                    trans_loss = (trans_loss / world_size).item()

                if rank == 0:
                    batch_time = time.time() - start
                    write_log(
                        logger,
                        f'[{global_step}/{total_iters}, {epoch}]\tIter time: {batch_time:.3f}\t'
                        f'Trans loss: {trans_loss:.3f}\tMask_loss: {mask_loss:.3f}\t'
                        f'Trans@5: {trans_top5:.3f}\tMask@5: {mask_top5:.3f}')

                    tb_logger.add_scalar('loss/translate', trans_loss,
                                         global_step)
                    tb_logger.add_scalar('loss/mask', mask_loss, global_step)
                    tb_logger.add_scalar('top5/translate', trans_top5,
                                         global_step)
                    tb_logger.add_scalar('top5/mask', mask_top5, global_step)
                    tb_logger.add_scalar('batch/time', batch_time, global_step)
                    tb_logger.add_scalar('batch/lr',
                                         optimizer.param_groups[0]['lr'],
                                         global_step)

            global_step += 1
            if finish_epoch:
                break

        # validation
        validate(model, tgt_mask, h_valid_loader, hk_valid_loader, rank,
                 logger, tb_logger, epoch, device)
        # save model
        if rank == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'model': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'scaler': scaler.state_dict()
                }, f'{args.save_path}/{args.model_name}_ckpt.pt')
            write_log(logger, f"***** {epoch}th model updated! *****")
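
For reference, the accumulation step used by both training branches in Example #3, reduced to a skeleton. This is a sketch, not the original code: accumulation_step and compute_loss are assumed names, compute_loss stands in for the model forward plus label-smoothing loss, and an initialized process group is assumed as in the example.

import torch.distributed as dist
from torch.cuda.amp import autocast

def accumulation_step(model, micro_batches, compute_loss, optimizer, scaler, world_size):
    # One optimizer step built from several micro-batches: scaled mixed-precision
    # backward passes, manual cross-rank gradient averaging, then scaler.step/update.
    optimizer.zero_grad(set_to_none=True)
    for batch in micro_batches:
        with autocast():
            loss = compute_loss(model, batch) / len(micro_batches)
        scaler.scale(loss).backward()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= world_size
    scaler.step(optimizer)
    scaler.update()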