コード例 #1
0
def train(rank, args):
    """Per-process training entry point.

    Initialises the process via ``setup``, seeds torch, pins the CUDA device,
    builds a ``TransformerModel`` wrapped in ``DDP``, creates the train/valid
    loaders and then drives a ``Trainer`` through the epoch loop.

    Args:
        rank: Index of this process; also used as the CUDA device index and
            as the device the model is moved to.
        args: Namespace of run options. ``args.local_rank`` is overwritten
            with ``rank`` as a side effect.
    """
    print(f"Running basic DDP example on rank {rank} {args.master_port}.")
    setup(rank, args.world_size, args.master_port)
    args.local_rank = rank
    torch.manual_seed(args.seed)
    torch.cuda.set_device(rank)

    source_dict = Dictionary.read_vocab(args.vocab_src)
    target_dict = Dictionary.read_vocab(args.vocab_tgt)
    per_batch = args.batch_size

    # model init
    model_options = dict(
        d_model=args.d_model,
        nhead=args.nhead,
        num_encoder_layers=args.num_encoder_layers,
        num_decoder_layers=args.num_decoder_layers,
        dropout=args.dropout,
        attention_dropout=args.attn_dropout,
        src_dictionary=source_dict,
        tgt_dictionary=target_dict,
    )
    network = TransformerModel(**model_options)
    network.to(rank)
    network = DDP(network, device_ids=[rank])

    # Only rank 0 dumps the architecture; the parameter counts below are
    # printed by every rank.
    if rank == 0:
        print(network)

    total_params = 0
    trainable_params = 0
    for weight in network.parameters():
        size = weight.numel()
        total_params += size
        if weight.requires_grad:
            trainable_params += size
    print('num. model params: {} (num. trained: {})'.format(
        total_params,
        trainable_params,
    ))

    # data load
    data = {
        'dataloader': {
            'train': dataloader.get_train_parallel_loader(
                args.train_src,
                args.train_tgt,
                source_dict,
                target_dict,
                batch_size=per_batch,
            ),
            'valid': dataloader.get_valid_parallel_loader(
                args.valid_src,
                args.valid_tgt,
                source_dict,
                target_dict,
                batch_size=per_batch,
            ),
        }
    }

    runner = Trainer(network, data, args)
    # range() excludes its end, so epochs 1 .. args.max_epoch - 1 are run.
    for epoch in range(1, args.max_epoch):
        runner.mt_step(epoch)
        runner.evaluate(epoch)
        runner.save_checkpoint(epoch)
コード例 #2
0
ファイル: train_ar.py プロジェクト: nyu-dl/dl4chem-mgm
def main(params):
    """Train the autoregressive Transformer, evaluating and saving once per epoch.

    The function seeds the RNGs, creates the experiment directory and logger,
    loads the data, then loops over training batches indefinitely until the
    global iteration counter equals ``params.max_steps``. On the last
    iteration of every epoch it computes validation accuracy/perplexity,
    saves a ``best_model`` checkpoint when accuracy improves, generates
    samples, scores them with three benchmarks and saves a ``model``
    checkpoint.

    Args:
        params: Namespace of run options. ``params.ar`` is set to ``True`` and
            ``params.n_langs`` to ``1`` as side effects.
    """
    # setup random seeds
    set_seed(params.seed)
    params.ar = True

    exp_path = os.path.join(params.dump_path, params.exp_name)
    # create exp path if it doesn't exist (exist_ok avoids a check-then-create race)
    os.makedirs(exp_path, exist_ok=True)
    # create logger
    logger = create_logger(os.path.join(exp_path, 'train.log'), 0)
    logger.info("============ Initialized logger ============")
    logger.info("Random seed is {}".format(params.seed))
    logger.info("\n".join("%s: %s" % (k, str(v))
                          for k, v in sorted(dict(vars(params)).items())))
    logger.info("The experiment will be stored in %s\n" % exp_path)
    # Parenthesised so the whole command line is the %s argument; the original
    # only produced the same text because % binds tighter than +.
    logger.info("Running command: %s" % ('python ' + ' '.join(sys.argv)))
    logger.info("")
    # load data
    data, loader = load_smiles_data(params)
    if params.data_type == 'ChEMBL':
        smiles_file = 'guacamol_v1_all.smiles'
    else:
        smiles_file = 'QM9_all.smiles'
    # Read inside a context manager so the file handle is closed promptly.
    with open(os.path.join(params.data_path, smiles_file), 'r') as smiles_fh:
        all_smiles_mols = smiles_fh.readlines()
    train_data, val_data = data['train'], data['valid']
    dico = data['dico']
    logger.info('train_data len is {}'.format(len(train_data)))
    logger.info('val_data len is {}'.format(len(val_data)))

    # keep cycling through train_loader forever
    # stop when max iters is reached
    def rcycle(iterable):
        """Yield items from ``iterable`` once, then replay them forever in shuffled order."""
        saved = []                 # In-memory cache of everything yielded so far
        for element in iterable:
            yield element
            saved.append(element)
        while saved:
            random.shuffle(saved)  # Reshuffle the order of cached items on every pass
            for element in saved:
                yield element
    train_loader = rcycle(train_data.get_iterator(shuffle=True, group_by_size=True, n_sentences=-1))

    # extra param names for transformermodel
    params.n_langs = 1
    # build Transformer model
    model = TransformerModel(params, is_encoder=False, with_output=True)

    if params.local_cpu is False:
        model = model.cuda()
    opt = get_optimizer(model.parameters(), params.optimizer)
    # Builtin float: the np.float alias was removed from NumPy.
    scores = {'ppl': float('inf'), 'acc': 0}

    # Offset added to the loop counter when resuming. Derived from the same
    # truthiness test that guards load_model, so an empty-string load_path can
    # no longer reach an undefined reloaded_iter.
    iter_offset = 0
    if params.load_path:
        reloaded_iter, scores = load_model(params, model, opt, logger)
        iter_offset = reloaded_iter + 1

    for step, train_batch in enumerate(train_loader):
        total_iter = step + iter_offset

        epoch = total_iter // params.epoch_size
        # NOTE(review): equality test - a resumed run that starts past
        # max_steps would never stop here; confirm whether >= is intended.
        if total_iter == params.max_steps:
            logger.info("============ Done training ... ============")
            break
        elif total_iter % params.epoch_size == 0:
            logger.info("============ Starting epoch %i ... ============" % epoch)
        model.train()
        opt.zero_grad()
        train_loss = calculate_loss(model, train_batch, params)
        train_loss.backward()
        if params.clip_grad_norm > 0:
            clip_grad_norm_(model.parameters(), params.clip_grad_norm)
        opt.step()
        if total_iter % params.print_after == 0:
            logger.info("Step {} ; Loss = {}".format(total_iter, train_loss))

        # Evaluate on the last iteration of each epoch.
        if total_iter > 0 and total_iter % params.epoch_size == (params.epoch_size - 1):
            # run eval step (calculate validation loss)
            model.eval()
            n_chars = 0
            xe_loss = 0
            n_valid = 0
            logger.info("============ Evaluating ... ============")
            val_loader = val_data.get_iterator(shuffle=True)
            for val_batch in val_loader:
                with torch.no_grad():
                    val_scores, val_loss, val_y = calculate_loss(model, val_batch, params, get_scores=True)
                # update stats
                n_chars += val_y.size(0)
                xe_loss += val_loss.item() * len(val_y)
                n_valid += (val_scores.max(1)[1] == val_y).sum().item()

            ppl = np.exp(xe_loss / n_chars)
            acc = 100. * n_valid / n_chars
            logger.info("Acc={}, PPL={}".format(acc, ppl))
            if acc > scores['acc']:
                scores['acc'] = acc
                scores['ppl'] = ppl
                save_model(params, data, model, opt, dico, logger, 'best_model', epoch, total_iter, scores)
                logger.info('Saving new best_model {}'.format(epoch))
                logger.info("Best Acc={}, PPL={}".format(scores['acc'], scores['ppl']))

            logger.info("============ Generating ... ============")
            number_samples = 100
            gen_smiles = generate_smiles(params, model, dico, number_samples)
            generator = ARMockGenerator(gen_smiles)

            validity_score = _benchmark_score_or_default(
                lambda: ValidityBenchmark(number_samples=number_samples),
                generator, logger, 'Validity')
            uniqueness_score = _benchmark_score_or_default(
                lambda: UniquenessBenchmark(number_samples=number_samples),
                generator, logger, 'Uniqueness')
            kldiv_score = _benchmark_score_or_default(
                lambda: KLDivBenchmark(number_samples=number_samples, training_set=all_smiles_mols),
                generator, logger, 'KLDiv')
            logger.info('Validity Score={}, Uniqueness Score={}, KlDiv Score={}'.format(validity_score, uniqueness_score, kldiv_score))
            save_model(params, data, model, opt, dico, logger, 'model', epoch, total_iter, {'ppl': ppl, 'acc': acc})


def _benchmark_score_or_default(make_benchmark, generator, logger, label, default=-1):
    """Build a benchmark, score ``generator`` with it, and fall back to ``default`` on failure.

    Benchmark construction and assessment are both best-effort: any
    ``Exception`` is logged with its traceback and ``default`` is returned so
    training continues. ``KeyboardInterrupt`` and ``SystemExit`` are not
    caught, so the run can still be interrupted.

    Args:
        make_benchmark: Zero-argument callable returning the benchmark object.
        generator: Object passed to the benchmark's ``assess_model``.
        logger: Logger used to report a failure.
        label: Benchmark name used in the failure message.
        default: Value returned when the benchmark raises.
    """
    try:
        return make_benchmark().assess_model(generator).score
    except Exception:
        logger.warning("%s benchmark failed; reporting %s", label, default, exc_info=True)
        return default