Example no. 1
    def get_optimizer(self,
                      lr,
                      t_total,
                      schedule_type='warmup_linear',
                      optimizer_type='lamb'):

        # Prepare optimiser and schedule
        no_decay = ['bias', 'LayerNorm.weight']

        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in self.model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            self.weight_decay
        }, {
            'params': [
                p for n, p in self.model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]

        if optimizer_type == 'lamb':
            optimizer = Lamb(optimizer_grouped_parameters,
                             lr=lr,
                             eps=self.adam_epsilon)
        elif optimizer_type == 'adamw':
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=lr,
                              eps=self.adam_epsilon)
        else:
            raise ValueError('unknown optimizer_type: ' + optimizer_type)

        schedule_class = SCHEDULES[schedule_type]

        scheduler = schedule_class(optimizer,
                                   warmup_steps=self.warmup_steps,
                                   t_total=t_total)

        return optimizer, scheduler
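
A minimal sketch of how the optimizer/scheduler pair returned above is typically consumed in a training step; `trainer`, `train_loader`, and `compute_loss` are assumed names, not part of the original class.

optimizer, scheduler = trainer.get_optimizer(lr=1e-3, t_total=10000)  # trainer is an assumption
for batch in train_loader:                       # train_loader is an assumption
    loss = compute_loss(trainer.model, batch)    # compute_loss is an assumption
    loss.backward()
    optimizer.step()
    scheduler.step()                             # advance the warmup/decay schedule once per step
    optimizer.zero_grad()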
Example no. 2
# At any point you can hit Ctrl + C to break out of training early.
try:
    optimizer = None
    # Ensure the optimizer is optimizing params, which includes both the model's weights as well as the criterion's weight (i.e. Adaptive Softmax)
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr, weight_decay=args.wdecay)
    elif args.optimizer == 'adagrad':
        optimizer = torch.optim.Adagrad(params, lr=args.lr, weight_decay=args.wdecay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.wdecay)
    elif args.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wdecay)
    elif args.optimizer == 'lamb':
        from pytorch_lamb import Lamb
        optimizer = Lamb(params, lr=args.lr, weight_decay=args.wdecay, min_trust=0.25)
        #optimizer = Lamb(params, lr=args.lr, weight_decay=args.wdecay, min_trust=0.1)
        #optimizer = Lamb(params, lr=args.lr, weight_decay=args.wdecay, min_trust=0, random_min_trust=0.2, random_trust_dice=10)
        #optimizer = Lamb(params, lr=args.lr, weight_decay=args.wdecay, min_trust=0.2, random_min_trust=0.5, random_trust_dice=4)
    assert optimizer is not None, 'Unknown optimizer: {}'.format(args.optimizer)
    from lookahead import Lookahead
    if False:
        k, alpha = 5, 0.8
        print('Lookahead - k {} and alpha {}'.format(k, alpha))
        optimizer = Lookahead(base_optimizer=optimizer, k=k, alpha=alpha)

    from apex import amp
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    #model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

    for epoch in range(1, args.epochs+1):
        epoch_start_time = time.time()
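
The snippet is cut off inside the epoch loop; below is a sketch of the apex-amp backward/clipping step that usually pairs with the `amp.initialize` call above (the same pattern appears in Example no. 8). `loss` and `args.clip` are assumptions here.

with amp.scale_loss(loss, optimizer) as scaled_loss:   # loss computed by the model forward (assumed)
    scaled_loss.backward()
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)  # args.clip is an assumption
optimizer.step()
optimizer.zero_grad()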
Example no. 3
def train(args):
    """Train E2E VC model."""
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())

    # In TTS, this is reversed, but not in VC. See `espnet.utils.training.batchfy`
    idim = int(valid_json[utts[0]]["input"][0]["shape"][1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # get extra input and output dimensions
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]["input"][1]["shape"][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        args.spc_dim = int(valid_json[utts[0]]["input"][1]["shape"][1])
    else:
        args.spc_dim = None

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to" + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # specify model architecture
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args, TTSInterface)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, TTSInterface)
    logging.info(model)
    reporter = model.reporter

    # freeze modules, if specified
    if args.freeze_mods:
        for mod, param in model.named_parameters():
            if any(mod.startswith(key) for key in args.freeze_mods):
                logging.info("freezing %s" % mod)
                param.requires_grad = False

    for mod, param in model.named_parameters():
        if not param.requires_grad:
            logging.info("Frozen module %s" % mod)

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.parameters() if p.requires_grad)
            * 100.0
            / sum(p.numel() for p in model.parameters()),
        )
    )

    # Setup an optimizer
    if args.opt == "adam":
        optimizer = torch.optim.Adam(
            model.parameters(), args.lr, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model, args.adim, args.transformer_warmup_steps, args.transformer_lr
        )
    elif args.opt == "lamb":
        from pytorch_lamb import Lamb

        optimizer = Lamb(
            model.parameters(), lr=args.lr, weight_decay=0.01, betas=(0.9, 0.999)
        )
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"
    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=False,
        iaxis=0,
        oaxis=0,
    )
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=False,
        iaxis=0,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(
        mode="vc",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    load_cv = LoadInputsAndTargets(
        mode="vc",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    converter = CustomConverter()
    # hack to keep the DataLoader batch_size argument at 1;
    # the actual batch size is carried inside each list element
    train_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                train_batchset, lambda data: converter([load_tr(data)])
            ),
            batch_size=1,
            num_workers=args.num_iter_processes,
            shuffle=not use_sortagrad,
            collate_fn=lambda x: x[0],
        )
    }
    valid_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                valid_batchset, lambda data: converter([load_cv(data)])
            ),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: x[0],
            num_workers=args.num_iter_processes,
        )
    }

    # Set up a trainer
    updater = CustomUpdater(
        model, args.grad_clip, train_iter, optimizer, device, args.accum_grad
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # set intervals
    eval_interval = (args.eval_interval_epochs, "epoch")
    save_interval = (args.save_interval_epochs, "epoch")
    report_interval = (args.report_interval_iters, "iteration")

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, device), trigger=eval_interval
    )

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger(
            "validation/main/loss", trigger=eval_interval
        ),
    )

    # Save attention figure for each epoch
    if args.num_save_attention > 0:
        data = sorted(
            list(valid_json.items())[: args.num_save_attention],
            key=lambda x: int(x[1]["input"][0]["shape"][1]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            reverse=True,
        )
        trainer.extend(att_reporter, trigger=eval_interval)
    else:
        att_reporter = None

    # Make a plot for training and validation values
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        plot_key = ["main/" + key, "validation/main/" + key]
        trainer.extend(
            extensions.PlotReport(plot_key, "epoch", file_name=key + ".png"),
            trigger=eval_interval,
        )
        plot_keys += plot_key
    trainer.extend(
        extensions.PlotReport(plot_keys, "epoch", file_name="all_loss.png"),
        trigger=eval_interval,
    )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys), trigger=report_interval)
    trainer.extend(extensions.ProgressBar(), trigger=report_interval)

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        from torch.utils.tensorboard import SummaryWriter

        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter), trigger=report_interval)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
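
A possible variant of the `args.opt == "lamb"` branch above that exempts bias and norm parameters from weight decay, reusing the grouping pattern from Example no. 1; this is a sketch, not part of espnet.

no_decay = ("bias", "norm")
grouped_params = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n.lower() for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n.lower() for nd in no_decay)],
     "weight_decay": 0.0},
]
optimizer = Lamb(grouped_params, lr=args.lr, betas=(0.9, 0.999))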
Example no. 4
def main():
    global global_token_count, event_writer, train_step, train_loss, last_log_step, \
        best_val_loss, epoch, model

    if args.local_rank > 0:
        pass  # skip shutdown when rank is explicitly set + not zero rank
    else:
        os.system('shutdown -c')

    if not args.local:
        logger.info(
            f'Distributed initializing process group with {args.dist_backend}, {args.dist_url}, {util.get_world_size()}'
        )
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert (util.get_world_size() == dist.get_world_size())
        logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})"
        )

    model = MemTransformerLM(ntokens,
                             args.n_layer,
                             args.n_head,
                             args.d_model,
                             args.d_head,
                             args.d_inner,
                             args.dropout,
                             args.dropatt,
                             tie_weight=args.tied,
                             d_embed=args.d_embed,
                             div_val=args.div_val,
                             tie_projs=tie_projs,
                             pre_lnorm=args.pre_lnorm,
                             tgt_len=args.tgt_len,
                             ext_len=args.ext_len,
                             mem_len=args.mem_len,
                             cutoffs=cutoffs,
                             same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)

    # log model info
    n_all_param = sum([p.nelement() for p in model.parameters()])
    log_tb('sizes/params', n_all_param)
    n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    log_tb('sizes/non_emb_params', n_nonemb_param)
    logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # optimizer
    if args.optim.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.mom)
    elif args.optim.lower() == 'lamb':
        optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=args.wd)
    else:
        assert args.optim.lower() == 'adam'
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.wd)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         args.max_tokens //
                                                         1e6,
                                                         eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        scheduler = LRFinder(optimizer,
                             args.max_tokens,
                             init_value=args.lr / 1e3)
    elif args.scheduler == 'constant':
        scheduler = None  # otherwise `scheduler` would be undefined below; train() is assumed to handle None

    model.apply(weights_init)
    model.word_emb.apply(
        weights_init
    )  # ensure embedding init is not overridden by out_layer in case of weight sharing

    if args.checkpoint:
        if global_rank == 0:
            util.restore_from_checkpoint(model=model,
                                         checkpoint_fn=args.checkpoint)

    model = model.to(device)
    if args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={'init_scale': 2**16},
                                   verbose=False)

    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)  #, find_unused_parameters=True)

    if global_rank == 0:
        event_writer = SummaryWriter(args.logdir)
        # only rank 0 creates the writer, so only rank 0 logs the args
        event_writer.add_text('args', str(args))

    # test checkpoint writing
    if args.checkpoint_each_epoch:
        logger.info(f'Saving checkpoint for epoch {epoch}')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix=f'{0}')

    # Loop over epochs.
    train_step = 0
    train_loss = 0
    last_log_step = 0
    best_val_loss = None
    va_iter, te_iter = [
        corpus.get_dist_iterator(split,
                                 global_rank,
                                 max_rank,
                                 args.batch_size * 2,
                                 args.tgt_len,
                                 device=device,
                                 ext_len=args.ext_len)
        for split in ('valid', 'test')
    ]

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=1):
            train(va_iter, optimizer, scheduler)
    except KeyboardInterrupt:
        logger.info('-' * 100)
        logger.info('Exiting from training early')
    except StopIteration:
        pass

    # Eval one more time.
    evaluate_and_log(optimizer, va_iter, 'val', train_step=-1)

    # Load the best saved model.
    logger.info("Loading best checkpoint")
    model_file = os.path.join(args.logdir, 'model-best.pt')
    if os.path.exists(model_file):
        with open(model_file, 'rb') as model_f:
            with timeit('load'):
                if args.local:
                    model = torch.load(model_f)
                else:
                    model = torch.load(model_f,
                                       map_location=lambda storage, loc:
                                       storage.cuda(args.local_rank))
                    model = DistributedDataParallel(
                        model,
                        device_ids=[args.local_rank],
                        output_device=args.local_rank)
    else:
        logger.warning('no model file, using current model for loss')

    # Run on test data.
    evaluate_and_log(optimizer, te_iter, 'test', -1)
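
With the 'constant' option, `scheduler` stays `None`, so the (not shown) train() loop would typically guard the schedule update; a small sketch of that guard, assuming the loop advances the schedule once per optimization step:

if scheduler is not None:
    scheduler.step()   # cosine / LR-finder schedules advance once per step; constant LR skips this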
Example no. 5
    if distributed:
        torch.distributed.init_process_group(backend='nccl')
        rank = torch.distributed.get_rank()
        torch.cuda.set_device(rank)

    dataset = enwik8()
    if not fresh:
        with open('model.pt', 'rb') as f:
            model = torch.load(f)
    else:
        model = SHARNN(n_token=dataset.n_token, embed_dim=1024, hidden_dim=4096, ff_dim=2048, n_layers=4, heads=1, max_len=5000, dropout=0.1, tied=True)
    model.to(device)
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, dim=1, find_unused_parameters=True)

    # optim = torch.optim.Adam(model.parameters(), lr=0.002)
    from pytorch_lamb import Lamb
    optim = Lamb(model.parameters(), lr=0.002, min_trust=0.25)

    crit = nn.CrossEntropyLoss().to(device)
    # sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, epochs)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=2)
    scaler = torch.cuda.amp.GradScaler()

    if True:
        train(model, crit, optim, sched, dataset, epochs)

    test_loss = evaluate(model, crit, dataset.test_data)
    print(f"Test | loss {test_loss:.3f} | ppl {math.exp(test_loss):.3f} | bpc {test_loss / math.log(2):.3f}")
    exit()
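
A self-contained sketch of the GradScaler/autocast step that the `scaler` created above is normally paired with inside `train`; this uses a toy model and data, not the SHA-RNN loop itself.

import torch

toy_model = torch.nn.Linear(8, 2).cuda()
toy_optim = torch.optim.SGD(toy_model.parameters(), lr=0.1)
toy_scaler = torch.cuda.amp.GradScaler()
x = torch.randn(4, 8, device='cuda')
y = torch.randint(0, 2, (4,), device='cuda')

with torch.cuda.amp.autocast():
    loss = torch.nn.functional.cross_entropy(toy_model(x), y)
toy_optim.zero_grad()
toy_scaler.scale(loss).backward()
toy_scaler.unscale_(toy_optim)                      # unscale before gradient clipping
torch.nn.utils.clip_grad_norm_(toy_model.parameters(), 0.25)
toy_scaler.step(toy_optim)
toy_scaler.update()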
Example no. 6
        0.01
    }, {
        "params":
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        "weight_decay":
        0.0
    }]

    if args["adam"]:
        optimizer = BertAdam(optimizer_group_parameters,
                             lr=args["learning_rate"],
                             warmup=args["warmup_proportion"],
                             t_total=num_train_steps)
    else:
        optimizer = Lamb(optimizer_group_parameters,
                         lr=args["learning_rate"],
                         weight_decay=args["warmup_proportion"])

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0

    if args["do_train"]:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args["max_seq_length"],
                                                      tokenizer)
        logger.info("******** Running training ********")
        logger.info("  Number of examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args["train_batch_size"])
        logger.info("  Num steps = %d", num_train_steps)
Example no. 7
def run_finetuning(args):
    torch.manual_seed(args.seed)
    device = torch.device(
        'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')

    # Get text columns
    t_columns = args.text_columns.split(',')
    num_texts = len(t_columns)
    if num_texts == 1: t_columns = t_columns[0]

    # Get label columns
    l_columns = args.label_columns.split(',')
    num_labels = len(l_columns)
    if num_labels == 1: l_columns = l_columns[0]

    if args.fp16 and not APEX_AVAILABLE:
        print(
            "FP16 toggle is on but Apex is not available. Using FP32 training."
        )

    if args.do_train:
        # Configure tokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
        if args.add_token != '':
            add_token = {
                'additional_special_tokens': args.add_token.split(',')
            }
            added = tokenizer.add_special_tokens(add_token)

        print('\n' + '=' * 50, '\nCONFIGURE FINETUNING SETUP', '\n' + '=' * 50)
        if args.add_token != '':
            print("Addded {} special tokens:".format(added), args.add_token)

        # Produce hash code for cache
        f_string = args.train_data + args.valid_data + str(args.msl) + str(
            args.seed) + args.pretrained + str(args.data_pct)
        hashed = 'cache_' + hashlib.md5(f_string.encode()).hexdigest() + '.pt'

        # Produce the dataset if cache doesn't exist
        if hashed not in os.listdir() or args.retokenize_data:
            print("Producing dataset cache. This will take a while.")
            s = time.time()

            df = pd.read_csv(args.train_data, lineterminator='\n').sample(
                frac=args.data_pct, random_state=args.seed)
            text, labels = df[t_columns].values, df[l_columns].values
            train_dataset = process_data(text, labels, tokenizer, msl=args.msl)

            df = pd.read_csv(args.valid_data, lineterminator='\n')
            text, labels = df[t_columns].values, df[l_columns].values
            valid_dataset = process_data(text, labels, tokenizer, msl=args.msl)

            if args.save_cache:
                print('Saving data cache')
                with open(hashed, 'wb') as f:
                    torch.save([train_dataset, valid_dataset], f)

            print("Preprocessing finished. Time elapsed: {:.2f}s".format(
                time.time() - s))

        # Load the dataset if the cache exists
        else:
            print('Cache found. Loading training and validation data.')
            with open(hashed, 'rb') as f:
                train_dataset, valid_dataset = torch.load(f)

        # Produce dataloaders
        train_sampler = data.RandomSampler(train_dataset)
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       sampler=train_sampler)
        valid_loader = data.DataLoader(valid_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=False)

        # Configure model
        config = AutoConfig.from_pretrained(
            args.pretrained, num_labels=2 if num_labels == 1 else num_labels)
        if args.random_init:
            print(
                "Initializing new randomly-initialized model from configuration"
            )
            model = AutoModelForSequenceClassification.from_config(config)
        else:
            print("Loading from pretrained checkpoint")
            model = AutoModelForSequenceClassification.from_pretrained(
                args.pretrained, config=config)
        _ = model.resize_token_embeddings(len(tokenizer))
        model = model.to(device)
        print("Model has {:,} trainable parameters".format(
            sum(p.numel() for p in model.parameters() if p.requires_grad)))

        # Configure loss function
        criterion = torch.nn.CrossEntropyLoss(
        ) if num_labels == 1 else torch.nn.BCEWithLogitsLoss()

        # Configure optimizer
        if args.optimizer == 'adam':
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [{
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                args.weight_decay
            }, {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0
            }]
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              eps=args.adam_epsilon)
            optimizer.zero_grad()
        elif args.optimizer == 'lamb':
            from pytorch_lamb import Lamb
            optimizer = Lamb(model.parameters(),
                             lr=args.learning_rate,
                             weight_decay=args.weight_decay,
                             betas=(args.adam_b1, args.adam_b2))
        else:
            raise ValueError('Unknown optimizer: ' + args.optimizer)

        # Configure FP16
        if args.fp16 and APEX_AVAILABLE:
            print("Using FP16 training.")
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.opt_level)

        # Configure scheduler
        if args.use_scheduler:
            steps = len(train_loader) * args.epochs // args.accumulation
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=int(steps * args.warmup_pct),
                num_training_steps=steps)
        else:
            scheduler = None

        print("Using learning rate {:.4E} and weight decay {:.4E}".format(
            args.learning_rate, args.weight_decay),
              end='')
        print(" with scheduler using warmup pct {}".format(
            args.warmup_pct)) if args.use_scheduler else print("")

        # Training proper
        print('\n' + '=' * 50, '\nTRAINING', '\n' + '=' * 50)
        print("Training batches: {} | Validation batches: {}".format(
            len(train_loader), len(valid_loader)))
        for e in range(1, args.epochs + 1):
            train_loss, train_acc = train(model,
                                          criterion,
                                          optimizer,
                                          train_loader,
                                          scheduler=scheduler,
                                          accumulation=args.accumulation,
                                          device=device,
                                          fp16=args.fp16)
            valid_loss, valid_acc = evaluate(model,
                                             criterion,
                                             valid_loader,
                                             device=device)
            print(
                "Epoch {:3} | Train Loss {:.4f} | Train Acc {:.4f} | Valid Loss {:.4f} | Valid Acc {:.4f}"
                .format(e, train_loss, train_acc, valid_loss, valid_acc))

            # Save the model
            model.save_pretrained(args.checkpoint)
            tokenizer.save_pretrained(args.checkpoint)
            #with open(args.checkpoint, 'wb') as f:
            #    torch.save(model.state_dict(), f)

    if args.do_eval:
        print('\n' + '=' * 50, '\nBEGIN EVALUATION PROPER', '\n' + '=' * 50)

        # Load saved tokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.checkpoint)

        # Produce hash code for test cache
        f_string = args.test_data + str(args.msl) + str(
            args.seed) + args.pretrained
        hashed = 'cache_' + hashlib.md5(f_string.encode()).hexdigest() + '.pt'

        # Produce the dataset if cache doesn't exist
        if hashed not in os.listdir() or args.retokenize_data:
            print("Producing test data cache. This will take a while.")
            s = time.time()

            df = pd.read_csv(args.test_data, lineterminator='\n')
            text, labels = df[t_columns].values, df[l_columns].values
            test_dataset = process_data(text, labels, tokenizer, msl=args.msl)

            if args.save_cache:
                print('Saving data cache')
                with open(hashed, 'wb') as f:
                    torch.save(test_dataset, f)

            print("Preprocessing finished. Time elapsed: {:.2f}s".format(
                time.time() - s))

        # Load the dataset if the cache exists
        else:
            print('Cache found. Loading test data.')
            with open(hashed, 'rb') as f:
                test_dataset = torch.load(f)

        # Dataloaders
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=args.batch_size,
                                                  shuffle=False)

        # Produce the model
        print("Loading finetuned checkpoint")
        model = AutoModelForSequenceClassification.from_pretrained(
            args.checkpoint)
        model = model.to(device)

        criterion = torch.nn.CrossEntropyLoss(
        ) if num_labels == 1 else torch.nn.BCEWithLogitsLoss()

        # Testing proper
        print('\n' + '=' * 50, '\nTESTING', '\n' + '=' * 50)
        test_loss, test_acc = evaluate(model,
                                       criterion,
                                       test_loader,
                                       device=device)
        print("Test Loss {:.4f} | Test Accuracy {:.4f}".format(
            test_loss, test_acc))

    # Logging
    if not args.do_train:
        train_loss, train_acc, valid_loss, valid_acc = None, None, None, None
    if not args.do_eval: test_loss, test_acc = None, None
    return train_loss, train_acc, valid_loss, valid_acc, test_loss, test_acc
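
An aside on the criterion selection in `run_finetuning` above: `CrossEntropyLoss` expects integer class indices, while the multi-label `BCEWithLogitsLoss` branch expects float targets shaped (batch, num_labels). A quick, self-contained illustration:

import torch

logits = torch.randn(4, 3)
single_label = torch.tensor([0, 2, 1, 0])              # class indices
multi_label = torch.randint(0, 2, (4, 3)).float()      # per-label float targets

ce = torch.nn.CrossEntropyLoss()(logits, single_label)
bce = torch.nn.BCEWithLogitsLoss()(logits, multi_label)
print(ce.item(), bce.item())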
Example no. 8
def train(args, train_dataset, model, tokenizer, pad_token_label_id):
    """
    Train the model
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]
    if args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    elif args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    else:
        raise Exception("Invalid optimizer specified")

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    #if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
    #    os.path.join(args.model_name_or_path, "scheduler.pt")
    #):
    # Load in optimizer and scheduler states
    #    optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
    #    scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.encoder_model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        if args.encoder_model_name_or_path.split("-")[-1].split(
                "/")[0].isdigit():
            global_step = int(
                args.encoder_model_name_or_path.split("-")[-1].split("/")[0])
        else:
            global_step = 0
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info(
            "  Will skip the first %d steps in the first epoch",
            steps_trained_in_current_epoch,
        )

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            input_ids, output_ids, input_mask, output_mask, _, decoder_ids = batch

            # add other inputs here, including kwargs
            inputs = {"input_ids": input_ids, "attention_mask": input_mask}

            # The output tuple structure depends on the model used and the arguments invoked
            # For BERT-type models, this is
            # decoder_predictions, encoded_embeddings, encoded_attention_mask = model(**inputs)
            # For GPT2-type models, this at least starts with the decoder predictions
            # See the EncoderDecoderModel class for more details
            output = model(**inputs)
            decoder_predictions = output[0]

            vocab_size = decoder_predictions.shape[-1]

            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                decoder_predictions.view(-1, vocab_size),
                output_ids.view(-1),
            )

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results, _ = evaluate(args,
                                              model,
                                              tokenizer,
                                              pad_token_label_id,
                                              mode="dev")
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar(
                        "loss",
                        (tr_loss - logging_loss) / args.logging_steps,
                        global_step,
                    )
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
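
A quick, self-contained check of the warmup schedule attached above via `get_linear_schedule_with_warmup`: the learning rate ramps to the base value over the warmup steps, then decays linearly toward zero at `num_training_steps` (a sketch with toy numbers).

import torch
from transformers import get_linear_schedule_with_warmup

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=1e-3)
sched = get_linear_schedule_with_warmup(opt, num_warmup_steps=10, num_training_steps=100)

lrs = []
for _ in range(100):
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])
print(max(lrs), lrs[-1])   # peaks near 1e-3 after warmup, close to 0 at the end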
Example no. 9
                               transforms.Normalize((0.5, 0.5, 0.5),
                                                    (0.5, 0.5, 0.5)),
                           ]))
print(args.batch_size)
dataloader = torch.utils.data.DataLoader(cifar10,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         drop_last=True)

net_G = Generator().to(device)
net_D = Discriminator().to(device)

optim_G = Lamb(net_G.parameters(),
               lr=args.lr,
               weight_decay=args.wd,
               betas=(.5, 0.999),
               adam=True)
optim_D = Lamb(net_D.parameters(),
               lr=args.lr,
               weight_decay=args.wd,
               betas=(.5, 0.999),
               adam=True)

train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
valid_writer = SummaryWriter(os.path.join(log_dir, 'valid'))

real_label = torch.full((args.batch_size, 1), 1.0).to(device)  # float fill so BCEWithLogitsLoss gets float targets
fake_label = torch.full((args.batch_size, 1), 0.0).to(device)
label = torch.cat([real_label, fake_label], dim=0)
criteria = nn.BCEWithLogitsLoss()
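
A sketch of how the concatenated real/fake label tensor above is typically consumed in the (not shown) discriminator update; `real_images`, the latent size, and the Generator/Discriminator forward signatures are assumptions.

z = torch.randn(args.batch_size, 128, device=device)      # latent size 128 is an assumption
fake_images = net_G(z).detach()                            # do not backprop into G during the D step
logits = torch.cat([net_D(real_images), net_D(fake_images)], dim=0)  # real_images is an assumption
loss_D = criteria(logits, label.float())
optim_D.zero_grad()
loss_D.backward()
optim_D.step()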
Example no. 10
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--optimizer',
                        type=str,
                        default='lamb',
                        choices=['lamb', 'adam'],
                        help='which optimizer to use')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=6,
                        metavar='N',
                        help='number of epochs to train (default: 6)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0025,
                        metavar='LR',
                        help='learning rate (default: 0.0025)')
    parser.add_argument('--wd',
                        type=float,
                        default=0.01,
                        metavar='WD',
                        help='weight decay (default: 0.01)')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')

    args = parser.parse_args()
    use_cuda = torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              **kwargs)

    model = Net().to(device)
    if args.optimizer == 'lamb':
        # When using extremely high lr such as 0.1, wd helps avoid diverging.
        # Also use better beta2 from https://www.fast.ai/2018/07/02/adam-weight-decay/
        optimizer = Lamb(model.parameters(),
                         lr=args.lr,
                         weight_decay=args.wd,
                         betas=(.9, .99))
    else:
        # Don't actually use the calculated trust ratio, which makes this equivalent to Adam.
        optimizer = Lamb(model.parameters(),
                         lr=args.lr,
                         weight_decay=args.wd,
                         betas=(.9, .99),
                         adam=True)
    writer = SummaryWriter()
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch, writer)
        test(args, model, device, test_loader, writer, epoch)
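
Per the comment in the `else` branch above, `adam=True` makes this Lamb implementation skip the trust-ratio scaling, so the update reduces to an Adam-style step. A short sketch comparing the two setups, using only the Lamb arguments that appear throughout these examples:

import torch
from pytorch_lamb import Lamb

net = torch.nn.Linear(10, 2)
lamb_opt = Lamb(net.parameters(), lr=0.0025, weight_decay=0.01, betas=(.9, .99))
adam_like = Lamb(net.parameters(), lr=0.0025, weight_decay=0.01, betas=(.9, .99), adam=True)  # trust ratio fixed at 1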
Example no. 11
class CycleGAN2Model(BaseModel):
    """
    This class implements the CycleGAN model, for learning image-to-image translation without paired data.

    The model training requires '--dataset_mode unaligned' dataset.
    By default, it uses a '--netG resnet_9blocks' ResNet generator,
    a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
    and a least-square GANs objective ('--gan_mode lsgan').

    CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.

        For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
        A (source domain), B (target domain).
        Generators: G_A: A -> B; G_B: B -> A.
        Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
        Forward cycle loss:  lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
        Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
        Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
        Dropout is not used in the original CycleGAN paper.
        """
        parser.set_defaults(
            no_dropout=True)  # default CycleGAN did not use dropout
        if is_train:
            parser.add_argument(
                '--optimizer',
                type=str,
                default='adam',
                help='optimizer to use [adam|lamb] (default: adam)')
            parser.add_argument('--wd',
                                type=float,
                                default=0,
                                help='weight decay (default: 0)')
            parser.add_argument('--lambda_A',
                                type=float,
                                default=10.0,
                                help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B',
                                type=float,
                                default=10.0,
                                help='weight for cycle loss (B -> A -> B)')
            parser.add_argument(
                '--lambda_identity',
                type=float,
                default=0.5,
                help=
                'use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1'
            )

        return parser

    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = [
            'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'
        ]
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_A=G_A(B) and idt_B=G_B(A)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                        opt.netG, opt.norm, not opt.no_dropout,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf,
                                        opt.netG, opt.norm, not opt.no_dropout,
                                        opt.init_type, opt.init_gain,
                                        self.gpu_ids)

        if self.isTrain:  # define discriminators
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            opt.init_type, opt.init_gain,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm,
                                            opt.init_type, opt.init_gain,
                                            self.gpu_ids)

        if self.isTrain:
            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                assert (opt.input_nc == opt.output_nc)
            self.fake_A_pool = ImagePool(
                opt.pool_size
            )  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(
                opt.pool_size
            )  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(
                self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            #self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            #self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))

            self.optimizer_G = Lamb(itertools.chain(self.netG_A.parameters(),
                                                    self.netG_B.parameters()),
                                    lr=opt.lr,
                                    weight_decay=opt.wd,
                                    adam=(opt.optimizer == 'adam'),
                                    betas=(opt.beta1, 0.999))
            self.optimizer_D = Lamb(itertools.chain(self.netD_A.parameters(),
                                                    self.netD_B.parameters()),
                                    lr=opt.lr,
                                    weight_decay=opt.wd,
                                    adam=(opt.optimizer == 'adam'),
                                    betas=(opt.beta1, 0.999))

            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        a, b = ('A', 'B') if AtoB else ('B', 'A')
        self.real_A = input[a].to(self.device)
        self.real_B = input[b].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

        if 'clean' + a in input and 'aug' + a in input:
            self.clean_A = input['clean' + a].to(self.device)
            self.aug_A = input['aug' + a].to(self.device)

        if 'clean' + b in input and 'aug' + b in input:
            self.clean_B = input['clean' + b].to(self.device)
            self.aug_B = input['aug' + b].to(self.device)

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)  # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)  # G_A(G_B(B))

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B"""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(
                self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(
                self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A,
                                                self.real_A) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B,
                                                self.real_B) * lambda_B
        # combined loss and calculate gradients
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()  # compute fake images and reconstruction images.
        # G_A and G_B
        self.set_requires_grad(
            [self.netD_A, self.netD_B],
            False)  # Ds require no gradients when optimizing Gs
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()  # calculate gradients for G_A and G_B
        self.optimizer_G.step()  # update G_A and G_B's weights
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()  # set D_A and D_B's gradients to zero
        self.backward_D_A()  # calculate gradients for D_A
        self.backward_D_B()  # calculate gradients for D_B
        self.optimizer_D.step()  # update D_A and D_B's weights
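
Both optimizers above swap CycleGAN's stock Adam for a pytorch-lamb style Lamb class and route the choice through the adam=(opt.optimizer == 'adam') flag, which, in the implementation this appears to follow, pins the layer-wise trust ratio to 1 so the update degenerates to Adam with decoupled weight decay. Below is a minimal, self-contained sketch of that trust-ratio idea; the function name and plain-tensor interface are illustrative, not the library's actual API.

import torch

def lamb_style_update(param, adam_step, lr, weight_decay=0.0, use_adam=False):
    """One layer-wise LAMB-style update on a plain tensor `param`.

    `adam_step` stands in for the bias-corrected Adam direction; with
    use_adam=True the trust ratio is fixed to 1, i.e. ordinary Adam with
    decoupled weight decay.
    """
    update = adam_step + weight_decay * param   # decoupled weight decay
    w_norm = param.norm().clamp(0, 10)          # clipped layer (weight) norm
    u_norm = update.norm()                      # norm of the proposed update
    if use_adam or w_norm == 0 or u_norm == 0:
        trust = 1.0
    else:
        trust = (w_norm / u_norm).item()        # layer-wise trust ratio
    param -= lr * trust * update
    return param

# Toy usage: one update of a random weight matrix with a random "Adam" direction.
p = torch.randn(64, 64)
p = lamb_style_update(p, torch.randn_like(p), lr=1e-3, weight_decay=1e-2)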
Example n. 13
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_dir', type=str, help='Directory of training samples')
    parser.add_argument('--valid_dir', type=str, help='Directory of validation samples')
    parser.add_argument('--test_dir', type=str, help='Directory of test samples')
    parser.add_argument('--src_vocab', type=str, help='SentencePiece vocabulary file for source sentence')
    parser.add_argument('--trg_vocab', type=str, help='SentencePiece vocabulary file for target sentence')
    parser.add_argument('--src_msl', type=int, default=100, help='Maximum sequence length for source sentence')
    parser.add_argument('--trg_msl', type=int, default=100, help='Maximum sequence length for target sentence')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
    parser.add_argument('--num_workers', type=int, default=1, help='Number of parallel workers for dataloading')
    parser.add_argument('--save_dir', type=str, help='Directory to save checkpoints and load checkpoints from')

    parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train')
    parser.add_argument('--tie_weights', action='store_true', help='Tie weights of encoder/decoder embeddings and projection layer')
    parser.add_argument('--hidden_dim', type=int, default=256, help='Hidden dimensions of the transformer layers')
    parser.add_argument('--n_layers', type=int, default=3, help='Number of transformer blocks in the encoder and decoder')
    parser.add_argument('--n_heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--pf_dim', type=int, default=512, help='Positionwise feedforward projection dimension')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout probability')
    parser.add_argument('--clip', type=float, default=1.0, help='Gradient clipping')

    parser.add_argument('--criterion', type=str, default='cross_entropy', choices=['cross_entropy', 'label_smoothing'], help='Criterion to use')
    parser.add_argument('--smoothing', type=float, default=0.0, help='Label smoothing factor')
    parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'adamw', 'lamb'], help='Optimizer to use')
    parser.add_argument('--learning_rate', type=float, default=3e-4, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay for parameters other than biases and LayerNorm weights')
    parser.add_argument('--adam_epsilon', type=float, default=1e-9, help='Epsilon value for Adam')
    parser.add_argument('--adam_b1', type=float, default=0.9, help='Beta1 for LAMB')
    parser.add_argument('--adam_b2', type=float, default=0.99, help='Beta2 for LAMB')
    parser.add_argument('--scheduler', type=str, default=None, choices=['cosine', 'linear', 'noam', None], help='Scheduler to use')
    parser.add_argument('--warmup_pct', type=float, default=0.1, help='Percentage of steps to warmup for linear scheduler')
    parser.add_argument('--warmup_steps', type=int, default=4000, help='Number of warmup steps for noam scheduling')
    
    parser.add_argument('--do_train', action='store_true', help='Train a model')
    parser.add_argument('--do_test', action='store_true', help='Evaluate a model')
    parser.add_argument('--resume_training', action='store_true', help='Toggle to resume from checkpoint in --save_dir')
    parser.add_argument('--no_cuda', action='store_true', help='Do not use GPU')
    parser.add_argument('--fp16', action='store_true', help='Use FP16 training via APEX')
    parser.add_argument('--opt_level', type=str, default='O1', choices=['O1', 'O2'], help='Optimization level for FP16 training')
    parser.add_argument('--save_every', type=int, default=1, help='Save checkpoint every --save_every epoch')
    parser.add_argument('--pad_token', type=str, default='<pad>', help='Override default padding token')

    parser.add_argument('--use_swa', action='store_true', help='Use stochastic weight averaging')
    parser.add_argument('--swa_pct', type=float, help='Last percentage of total training steps to average')
    parser.add_argument('--swa_times', type=int, help='Number of times to average over swa_pct')

    parser.add_argument('--seed', type=int, default=1111, help='Random seed')
    
    args = parser.parse_args()
    print(args)

    # Set seeds
    torch.manual_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    
    if args.do_train:
        # Produce dataloaders
        print("Producing dataloaders.")
        train_dataset = TextDataset(args.train_dir, args.src_vocab, args.trg_vocab, src_msl=args.src_msl, trg_msl=args.trg_msl)
        train_sampler = torch.utils.data.RandomSampler(train_dataset)
        train_loader = torch.utils.data.DataLoader(train_dataset, 
                                                  sampler=train_sampler,
                                                  batch_size=args.batch_size, 
                                                  collate_fn=collate_fn, 
                                                  num_workers=args.num_workers)

        valid_dataset = TextDataset(args.valid_dir, args.src_vocab, args.trg_vocab, src_msl=args.src_msl, trg_msl=args.trg_msl)
        valid_loader = torch.utils.data.DataLoader(valid_dataset, 
                                                  shuffle=False,
                                                  batch_size=args.batch_size, 
                                                  collate_fn=collate_fn, 
                                                  num_workers=args.num_workers)

        print("Training batches: {}\nValidation batches: {}".format(len(train_loader), len(valid_loader)))

        # Produce model and criterion
        encoder = Encoder(vocab_sz=train_dataset.src_vocab_sz, 
                          hidden_dim=args.hidden_dim, 
                          n_layers=args.n_layers, 
                          n_heads=args.n_heads, 
                          pf_dim=args.pf_dim, 
                          dropout=args.dropout, 
                          msl=args.src_msl, 
                          fp16=args.fp16)
        decoder = Decoder(vocab_sz=train_dataset.trg_vocab_sz, 
                          hidden_dim=args.hidden_dim, 
                          n_layers=args.n_layers, 
                          n_heads=args.n_heads, 
                          pf_dim=args.pf_dim, 
                          dropout=args.dropout, 
                          msl=args.trg_msl, 
                          fp16=args.fp16)
        model = Seq2Seq(encoder, decoder, train_dataset.src_word2idx[args.pad_token], train_dataset.trg_word2idx[args.pad_token], tie_weights=args.tie_weights).to(device)
        
        # Configure SWA
        if args.use_swa: swa_model = AveragedModel(model)
        else: swa_model = None

        # Produce criterion
        if args.criterion == 'cross_entropy':
            criterion = nn.CrossEntropyLoss(ignore_index=train_dataset.src_word2idx[args.pad_token])
        elif args.criterion == 'label_smoothing':
            criterion = LabelSmoothingLoss(epsilon=args.smoothing)

        # Produce Optimizer
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [{"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 
                                         "weight_decay": args.weight_decay}, 
                                        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 
                                         "weight_decay": 0.0}]

        if args.optimizer == 'adamw':
            try:
                from transformers import AdamW
                optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon, betas=(args.adam_b1, args.adam_b2))
            except ModuleNotFoundError:
                print("Transformers module not found for AdamW. Using generic Adam instead.")
                optimizer = optim.Adam(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon, betas=(args.adam_b1, args.adam_b2))
        elif args.optimizer == 'lamb':
            try:
                from pytorch_lamb import Lamb
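                # Unlike the Adam/AdamW branches, Lamb is given model.parameters()
                # directly, so the no_decay grouping built above is not used here and
                # weight_decay also applies to bias and LayerNorm parameters.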
                optimizer = Lamb(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, betas=(args.adam_b1, args.adam_b2))
            except ModuleNotFoundError:
                print("LAMB implementation not found. Using generic Adam instead.")
                optimizer = optim.Adam(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon, betas=(args.adam_b1, args.adam_b2))
        elif args.optimizer == 'adam':
            optimizer = optim.Adam(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon, betas=(args.adam_b1, args.adam_b2))

        # Configure FP16
        if args.fp16 and APEX_AVAILABLE:
            model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level)

        # Produce the scheduler
        if args.scheduler == 'cosine':
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader) * args.epochs)
        elif args.scheduler == 'linear':
            try:
                from transformers import get_linear_schedule_with_warmup
                steps = args.epochs * len(train_loader)
                scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(steps * args.warmup_pct), num_training_steps=steps)
            except ModuleNotFoundError:
                print('Transformers module not found for Linear Schedule. Not using a scheduler.')
                scheduler = None
        elif args.scheduler == 'noam':
            scheduler = NoamLR(optimizer, warmup_steps=args.warmup_steps)
        else:
            scheduler = None

        print("\nUsing {} optimizer with {} scheduling. Optimizing via {}.".format(str(type(optimizer)), str(type(scheduler)), str(type(criterion))))
        print("Model has {:,} trainable parameters.\n".format(count_parameters(model)))

        # Configure states if resuming from checkpoint
        if args.resume_training:
            print("Loading from checkpoint...", end='')
            with open(args.save_dir + '/model.bin', 'rb') as f:
                model.load_state_dict(torch.load(f))
                print('Model loaded...', end='')
            if args.use_swa:
                with open(args.save_dir + '/swa_model.bin', 'rb') as f:
                    swa_model.load_state_dict(torch.load(f))
                    print('SWA Model loaded...', end='')
            with open(args.save_dir + '/training.bin', 'rb') as f:
                training_state = torch.load(f)
                optimizer.load_state_dict(training_state['opt_state'])
                e = training_state['e'] + 1 # Start on the next epoch
                print('Optimizer loaded...', end='')

                if training_state['scheduler'] is not None:
                    scheduler.load_state_dict(training_state['scheduler'])
                    print('Scheduler loaded...', end='')
                else:
                    print('No scheduler found...')
            global_steps = len(train_loader) * (e - 1)
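            # e already points at the next epoch to run, so (e - 1) full epochs
            # worth of batches have been consumed.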
            print("Done!\nLoaded checkpoint from epoch {} | Global step {}!".format(training_state['e'], global_steps))
            
        # Else, begin from epoch 1
        else:
            print("Beginning training from epoch 1.")
            e = 1
            global_steps = 0

        # Print training setup
        total_steps = len(train_loader) * (args.epochs)
        print('Total number of steps: {}'.format(total_steps))
        
        # Configure SWA points
        if args.use_swa: 
            swa_every = sorted(list(set([round(total_steps * (1 - args.swa_pct * i)) for i in range(args.swa_times)])))
            print('SWA on steps: {}\n'.format(swa_every)) 
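            # Illustration with made-up numbers: total_steps=1000, swa_pct=0.05 and
            # swa_times=3 give swa_every == [900, 950, 1000], i.e. the weights are
            # averaged at the last three 5%-spaced points of training.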
        else: 
            swa_every = None
            print('\n')

        # Train Model
        while e < args.epochs + 1:
            # Train one epoch
            train_loss, global_steps = train(model, criterion, optimizer, train_loader, global_steps, 
                                            device=device, clip=args.clip, scheduler=scheduler, fp16=args.fp16, 
                                            swa=args.use_swa, swa_every=swa_every, swa_model=swa_model)
            valid_loss = evaluate(model, criterion, valid_loader, device=device)

            print("Epoch {:3} | Train Loss {:.4f} | Train Ppl {:.4f} | Valid Loss {:.4f} | Valid Ppl {:.4f}".format(e, train_loss, np.exp(train_loss), valid_loss, np.exp(valid_loss)))

            # Save the checkpoint
            if e % args.save_every == 0 or e == args.epochs:
                print('Saving checkpoint for epoch {}...'.format(e), end='')
                save_checkpoint(model, args, optimizer=optimizer, e=e, scheduler=scheduler, save_state=True, swa_model=swa_model)
                print('Done!')
            
            # Update epoch number
            e += 1

        # Evaluate again and save if we're using SWA
        if args.use_swa:
            print('Evaluating on final averaged model.')
            valid_loss = evaluate(swa_model, criterion, valid_loader, device=device)
            print("Valid Loss {:.4f} | Valid Ppl {:.4f}".format(valid_loss, np.exp(valid_loss)))
            print('Saving checkpoint for averaged model.')
            save_checkpoint(model, args, optimizer=optimizer, e=e, scheduler=scheduler, save_state=True, swa_model=swa_model)
            print('Done!')

        print("Training done!\n")

    if args.do_test:    
        # Produce dataloaders
        print("Producing test loaders.")
        test_dataset = TextDataset(args.test_dir, args.src_vocab, args.trg_vocab, src_msl=args.src_msl, trg_msl=args.trg_msl)
        test_loader = torch.utils.data.DataLoader(test_dataset, 
                                                  shuffle=False,
                                                  batch_size=args.batch_size, 
                                                  collate_fn=collate_fn, 
                                                  num_workers=args.num_workers)
        
        print("Number of testing batches: {}".format(len(test_loader)))

        # Produce setup
        print("Loading model and saved settings.")
        with open(args.save_dir + '/settings.bin', 'rb') as f:
            hd, nl, nh, pf, dp, smsl, tmsl, tw, usw, cri = torch.load(f)

        encoder = Encoder(vocab_sz=test_dataset.src_vocab_sz, 
                          hidden_dim=hd, 
                          n_layers=nl, 
                          n_heads=nh, 
                          pf_dim=pf, 
                          dropout=dp, 
                          msl=smsl, 
                          fp16=args.fp16)
        decoder = Decoder(vocab_sz=test_dataset.trg_vocab_sz, 
                          hidden_dim=hd, 
                          n_layers=nl, 
                          n_heads=nh, 
                          pf_dim=pf, 
                          dropout=dp, 
                          msl=tmsl, 
                          fp16=args.fp16)
        model = Seq2Seq(encoder, decoder, test_dataset.src_word2idx[args.pad_token], test_dataset.trg_word2idx[args.pad_token], tie_weights=tw)

        if cri == 'cross_entropy':
            criterion = nn.CrossEntropyLoss(ignore_index=test_dataset.trg_word2idx[args.pad_token])
        elif cri == 'label_smoothing':
            criterion = LabelSmoothingLoss(epsilon=args.smoothing)

        # Load the checkpoint
        if usw:
            swa_model = AveragedModel(model)
            with open(args.save_dir + '/swa_model.bin', 'rb') as f:
                swa_model.load_state_dict(torch.load(f))
                swa_model = swa_model.to(device)

        with open(args.save_dir + '/model.bin', 'rb') as f:
            model.load_state_dict(torch.load(f))
        model = model.to(device)
        
        print("\nBegin testing.")
        test_loss = evaluate(model, criterion, test_loader, device=device)
        print("Test Loss {:.4f} | Test Ppl {:.4f}".format(test_loss, np.exp(test_loss)))

        if usw:
            print('Testing SWA model.')
            test_loss = evaluate(swa_model, criterion, test_loader, device=device)
            print("Test Loss {:.4f} | Test Ppl {:.4f}".format(test_loss, np.exp(test_loss)))