def plot_latent_interpolations(self, attr_str, dim, num_points=10):
    x1 = torch.linspace(-4.0, 4.0, num_points)
    _, _, data_loader = self.dataset.data_loaders(batch_size=1)
    for sample_id, batch in tqdm(enumerate(data_loader)):
        if sample_id in [0, 1, 2]:
            inputs, labels = self.process_batch_data(batch)
            inputs = to_cuda_variable(inputs)
            recons, _, _, z, _ = self.model(inputs)
            recons = torch.sigmoid(recons)
            z = z.repeat(num_points, 1)
            z[:, dim] = x1.contiguous()
            outputs = torch.sigmoid(self.model.decode(z))
            # save interpolation
            save_filepath = os.path.join(
                Trainer.get_save_dir(self.model),
                f'latent_interpolations_{attr_str}_{sample_id}.png')
            save_image(outputs.cpu(),
                       save_filepath,
                       nrow=num_points,
                       pad_value=1.0)
            # save original image
            org_save_filepath = os.path.join(
                Trainer.get_save_dir(self.model),
                f'original_{sample_id}.png')
            save_image(inputs.cpu(),
                       org_save_filepath,
                       nrow=1,
                       pad_value=1.0)
            # save reconstruction
            recons_save_filepath = os.path.join(
                Trainer.get_save_dir(self.model),
                f'recons_{sample_id}.png')
            save_image(recons.cpu(),
                       recons_save_filepath,
                       nrow=1,
                       pad_value=1.0)
        if sample_id == 5:
            break
Example #2
def train(model: tf.keras.Model, dataset: DatasetCreator):
    optimizer = tf.keras.optimizers.RMSprop(learning_rate=ModelConfig.LR)
    loss_fn = tf.losses.SparseCategoricalCrossentropy(from_logits=False)
    trainer = Trainer(model, optimizer, loss_fn, dataset)

    if DataConfig.USE_TB:
        tensorboard = TensorBoard(tb_dir=DataConfig.TB_DIR)
        # Creates the TensorBoard Graph. The result is not great, but it's a start.
        tf.summary.trace_on(graph=True, profiler=False)
        trainer.val_step(tf.expand_dims(tf.zeros(dataset.input_shape), 0), 0)
        with tensorboard.file_writer.as_default():
            tf.summary.trace_export(name=ModelConfig.NETWORK_NAME, step=0)

    best_loss = 1000
    last_checkpoint_epoch = 0

    for epoch in range(ModelConfig.MAX_EPOCHS):
        print(f"\nEpoch {epoch}/{ModelConfig.MAX_EPOCHS}")
        epoch_start_time = time.time()

        train_loss, train_acc = trainer.train_epoch()

        if DataConfig.USE_TB:
            tensorboard.write_metrics(train_loss, train_acc, epoch)
            tensorboard.write_lr(ModelConfig.LR, epoch)

        if (train_loss < best_loss and DataConfig.USE_CHECKPOINT and
                epoch >= DataConfig.RECORD_DELAY and (epoch - last_checkpoint_epoch) >= DataConfig.CHECKPT_SAVE_FREQ):

            save_path = os.path.join(DataConfig.CHECKPOINT_DIR, f'train_{epoch}')
            print(f"Loss improved from {best_loss} to {train_loss}, saving model to {save_path}")
            best_loss = train_loss
            last_checkpoint_epoch = epoch
            model.save(save_path)

        print(f"\nEpoch loss: {train_loss}, Train accuracy: {train_acc}  -  Took {time.time() - epoch_start_time:.5f}s")

        # Validation and (expensive to compute) metrics
        if epoch % DataConfig.VAL_FREQ == 0 and epoch > DataConfig.RECORD_DELAY:
            validation_start_time = time.time()
            val_loss, val_acc = trainer.val_epoch()

            if DataConfig.USE_TB:
                tensorboard.write_metrics(val_loss, val_acc, epoch, mode="Validation")

                # Metrics for the Train dataset
                imgs, labels = list(dataset.train_dataset.take(1).as_numpy_iterator())[0]
                predictions = model.predict(imgs)
                tensorboard.write_predictions(imgs, predictions, labels, epoch, mode="Train")

                # Metrics for the Validation dataset
                imgs, labels = list(dataset.val_dataset.take(1).as_numpy_iterator())[0]
                predictions = model.predict(imgs)
                tensorboard.write_predictions(imgs, predictions, labels, epoch, mode="Validation")

            print(f"\nValidation loss: {val_loss}, Validation accuracy: {val_acc}",
                  f"-  Took {time.time() - validation_start_time:.5f}s", flush=True)
Example #3
def train(num_epochs, use_cuda, batch_size, wandb_name, subsample,
          checkpoint_epochs):
    trainer = Trainer(use_cuda, wandb_name)
    trainer.setup_checkpoints(CHECKPOINT_NAME, checkpoint_epochs)
    trainer.setup_wandb(
        WANDB_PROJECT,
        wandb_name,
        config={
            "Batch Size": batch_size,
            "Epochs": num_epochs,
            "Adam Betas": ADAM_BETAS,
            "Learning Rate": LEARNING_RATE,
            "Weight Decay": WEIGHT_DECAY,
            "Fine Tuning": False,
        },
    )
    train_loader, test_loader = trainer.load_data_loaders(
        Dataset, batch_size, subsample)
    trainer.register_loss_fn(get_mse_loss)
    trainer.register_metric_fn(get_mse_metric, "Loss")
    trainer.input_shape = [2**15]
    trainer.target_shape = [2**15]
    trainer.output_shape = [2**15]
    net = trainer.load_net(WaveUNet)

    opt_kwargs = {
        "adam_betas": ADAM_BETAS,
        "weight_decay": WEIGHT_DECAY,
    }

    # Set net to train on clean speech only as an autoencoder
    net.skips_enabled = False
    trainer.test_set.clean_only = True
    trainer.train_set.clean_only = True

    # Fiddle with learning rate because autoencoder is not very good w/o skip conns.
    optimizer = trainer.load_optimizer(net, learning_rate=1e-4, **opt_kwargs)
    trainer.train(net, 5, optimizer, train_loader, test_loader)

    optimizer = trainer.load_optimizer(net, learning_rate=1e-5, **opt_kwargs)
    trainer.train(net, 5, optimizer, train_loader, test_loader)

    optimizer = trainer.load_optimizer(net, learning_rate=1e-6, **opt_kwargs)
    trainer.train(net, 5, optimizer, train_loader, test_loader)

    # Set net to train on noisy speech
    optimizer = trainer.load_optimizer(
        net,
        learning_rate=LEARNING_RATE,
        **opt_kwargs,
    )
    # net.freeze_encoder()
    net.skips_enabled = True
    trainer.test_set.clean_only = False
    trainer.train_set.clean_only = False
    trainer.train(net, num_epochs, optimizer, train_loader, test_loader)
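get_mse_loss and get_mse_metric are not defined in these examples, but the registered loss and metric functions elsewhere in this collection (e.g. get_feature_loss further below) all take (inputs, outputs, targets). Under that assumption, a plausible minimal sketch of the pair:

import torch.nn.functional as F

def get_mse_loss(inputs, outputs, targets):
    # Differentiable loss used for backprop.
    return F.mse_loss(outputs, targets)

def get_mse_metric(inputs, outputs, targets):
    # Plain float for logging only.
    return F.mse_loss(outputs, targets).item()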
Example #4
def main(args):
    config = tf.ConfigProto(
        allow_soft_placement=args.allow_soft_placement,
        gpu_options=tf.GPUOptions(allow_growth=args.allow_gpu_growth))

    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    if args.no_trans_repr:
        kbqa_model = None
        trans_sess = None
    else:
        # load transferred model params
        config_path = "%s/config.json" % args.model_dir
        with open(config_path, 'r') as fr:
            kbqa_model_config = json.load(fr)

        trans_graph = tf.Graph()
        with trans_graph.as_default():
            kbqa_model = KbqaModel(**kbqa_model_config)
            trans_saver = tf.train.Saver()

        trans_sess = tf.Session(config=config, graph=trans_graph)
        model_path = '%s/model_best/best.model' % args.model_dir
        trans_saver.restore(trans_sess, save_path=model_path)

    if args.no_trans_select:
        feed_parm = None
    else:
        param_path = '%s/detail/param.best.pkl' % args.model_dir
        with open(param_path, 'rb') as fr:
            param_dict = pickle.load(fr)
            if len(param_dict.keys()) == 1:
                feed_parm = {
                    'bilinear_mat':
                    param_dict['rm_task/rm_forward/bilinear_mat']
                }
                print("bilinear_mat:", feed_parm['bilinear_mat'].shape)
            else:
                feed_parm = {
                    'fc1_weights':
                    param_dict['rm_task/rm_forward/fc1/weights'],
                    'fc1_biases': param_dict['rm_task/rm_forward/fc1/biases'],
                    'fc2_weights':
                    param_dict['rm_task/rm_forward/fc2/weights'],
                    'fc2_biases': param_dict['rm_task/rm_forward/fc2/biases']
                }
                print("fc1_weights:", feed_parm['fc1_weights'].shape)
                print("fc1_biases:", feed_parm['fc1_biases'].shape)
                print("fc2_weights:", feed_parm['fc2_weights'].shape)
                print("fc2_biases:", feed_parm['fc2_biases'].shape)

    # load knowledge
    kd_loader = KnowledgeLoader(args.kd_dir)
    word_vocab, word_embed = kd_loader.load_vocab(vocab_size=args.vocab_size,
                                                  embed_dim=args.dim_emb)
    kd_vocab, kd_embed = kd_loader.load_entity_relation()
    csk_entity_list = kd_loader.load_csk_entities()

    main_graph = tf.Graph()
    with main_graph.as_default():
        use_trans_repr = not args.no_trans_repr
        use_trans_select = not args.no_trans_select
        use_guiding = not args.no_use_guiding
        if args.multi_step:
            model = TransDGModelMultistep(word_embed,
                                          kd_embed,
                                          feed_parm,
                                          use_trans_repr=use_trans_repr,
                                          use_trans_select=use_trans_select,
                                          use_guiding=use_guiding,
                                          vocab_size=args.vocab_size,
                                          dim_emb=args.dim_emb,
                                          dim_trans=args.dim_trans,
                                          cell_class=args.cell_class,
                                          num_units=args.num_units,
                                          num_layers=args.num_layers,
                                          max_length=args.max_dec_len,
                                          lr_rate=args.lr_rate,
                                          max_grad_norm=args.max_grad_norm,
                                          drop_rate=args.drop_rate)
        else:
            model = TransDGModel(word_embed,
                                 kd_embed,
                                 feed_parm,
                                 use_trans_repr=use_trans_repr,
                                 use_trans_select=use_trans_select,
                                 vocab_size=args.vocab_size,
                                 dim_emb=args.dim_emb,
                                 dim_trans=args.dim_trans,
                                 cell_class=args.cell_class,
                                 num_units=args.num_units,
                                 num_layers=args.num_layers,
                                 max_length=args.max_dec_len,
                                 lr_rate=args.lr_rate,
                                 max_grad_norm=args.max_grad_norm,
                                 drop_rate=args.drop_rate)
        saver = tf.train.Saver(max_to_keep=5)
        best_saver = tf.train.Saver()

    sess = tf.Session(config=config, graph=main_graph)

    if args.mode == 'train':
        if tf.train.get_checkpoint_state("%s/models" % args.log_dir):
            model_path = tf.train.latest_checkpoint("%s/models" % args.log_dir)
            print("model restored from [%s]" % model_path)
            saver.restore(sess, model_path)
        else:
            print("create model with init parameters...")
            with main_graph.as_default():
                sess.run(tf.global_variables_initializer())
                model.set_vocabs(sess, word_vocab, kd_vocab)
                if args.verbose > 0:
                    model.show_parameters()

        train_chunk_list = []
        with open("%s/all_list" % args.data_dir, 'r') as fr:
            for line in fr:
                train_chunk_list.append(line.strip())
        train_batcher = DataBatcher(data_dir=args.data_dir,
                                    file_list=train_chunk_list,
                                    batch_size=args.batch_size,
                                    num_epoch=args.max_epoch,
                                    shuffle=True)
        # wait for train_batcher queue caching
        print("Loading data from [%s/all_list]" % args.data_dir)
        while not train_batcher.full():
            time.sleep(5)
            print("loader queue caching...")

        valid_loader = DataLoader(batch_size=args.batch_size)
        valid_path = "%s/valid.pkl" % args.data_dir
        valid_loader.load_data(file_path=valid_path)

        # train model
        trainer = Trainer(model=model,
                          sess=sess,
                          trans_model=kbqa_model,
                          trans_sess=trans_sess,
                          saver=saver,
                          best_saver=best_saver,
                          log_dir=args.log_dir,
                          save_per_step=args.save_per_step)

        for epoch_idx in range(args.max_epoch):
            print("Epoch %d:" % (epoch_idx + 1))
            trainer.train(train_batcher, valid_loader, epoch_idx=epoch_idx)

    else:
        if args.ckpt == 'best.model':
            model_path = "%s/models/best_model/best.model" % args.log_dir
        else:
            model_path = "%s/models/model-%s" % (args.log_dir, args.ckpt)
        print("model restored from [%s]" % model_path)
        saver.restore(sess, model_path)

        test_loader = DataLoader(batch_size=args.batch_size)
        test_path = "%s/test.pkl" % args.data_dir
        test_loader.load_data(file_path=test_path)

        with open('%s/stopwords' % args.kd_dir, 'r') as f:
            stop_words = json.loads(f.readline())

        # test model on test set
        generator = Generator(model=model,
                              sess=sess,
                              trans_model=kbqa_model,
                              trans_sess=trans_sess,
                              log_dir=args.log_dir,
                              ckpt=args.ckpt,
                              stop_words=stop_words,
                              csk_entities=csk_entity_list)
        generator.generate(test_loader)
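The transferred KBQA model above is restored into its own tf.Graph and tf.Session so its variables cannot collide with the main dialogue graph. The general TF1-style pattern, sketched with a hypothetical build_model function standing in for the real model constructor:

import tensorflow as tf  # TF1-style APIs (tf.compat.v1 in TF2)

graph = tf.Graph()
with graph.as_default():
    model = build_model()      # hypothetical; creates variables in this graph only
    saver = tf.train.Saver()   # the saver is bound to this graph's variables

sess = tf.Session(graph=graph)
saver.restore(sess, "path/to/model.ckpt")  # placeholder checkpoint path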
Example #5
def run_train(args):
    # read data
    task_data = TaskData(task="task1")
    processor = BertProcessor(vocab_path="%s/vocab.txt" % args.bert_model_dir,
                              do_lower_case=True)

    train_path = "%s/train.data.txt" % args.data_dir
    train_data = task_data.read_data(data_path=train_path,
                                     is_train=True,
                                     with_pattern=True,
                                     shuffle=True)
    train_examples = processor.create_examples(lines=train_data,
                                               example_type='train')
    train_features = processor.create_features(examples=train_examples,
                                               max_seq_len=args.max_seq_len)
    train_dataset = processor.create_dataset(train_features,
                                             with_pattern=True,
                                             is_sorted=args.sorted)
    train_sampler = SequentialSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  collate_fn=collate_fn)

    valid_path = "%s/dev.data.txt" % args.data_dir
    valid_data = task_data.read_data(data_path=valid_path,
                                     is_train=True,
                                     with_pattern=True,
                                     shuffle=False)
    valid_examples = processor.create_examples(lines=valid_data,
                                               example_type='valid')
    valid_features = processor.create_features(examples=valid_examples,
                                               max_seq_len=args.max_seq_len)
    valid_dataset = processor.create_dataset(valid_features, with_pattern=True)
    valid_sampler = SequentialSampler(valid_dataset)
    valid_dataloader = DataLoader(valid_dataset,
                                  sampler=valid_sampler,
                                  batch_size=args.batch_size,
                                  collate_fn=collate_fn)
    # save vocab to log dir
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    processor.tokenizer.save_pretrained(args.log_dir)
    torch.save(args, "%s/train_args.bin" % (args.log_dir))
    logging.info("Training/evaluation parameters %s", args)

    # set model
    logging.info("initializing model")
    if args.resume_path:
        args.resume_path = Path(args.resume_path)
        model = BertForClassifier.from_pretrained(
            args.resume_path, num_labels=task_data.get_num_labels())
    else:
        model = BertForClassifier.from_pretrained(
            args.bert_model_dir, num_labels=task_data.get_num_labels())

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': args.weight_decay,
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0,
        },
    ]
    t_total = int(
        len(train_dataloader) / args.gradient_accumulation_steps *
        args.num_epochs)
    warmup_steps = int(t_total * args.warmup_proportion)
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=warmup_steps,
                                                num_training_steps=t_total)

    model_checkpoint = ModelCheckpoint(checkpoint_dir=args.log_dir,
                                       mode=args.mode,
                                       save_best_only=args.save_best)

    # training
    logging.info("======= Running training =======")
    logging.info("Num examples = %d" % len(train_examples))
    logging.info("Num epochs = %d" % args.num_epochs)
    logging.info("Gradient accumulation steps = %d" %
                 args.gradient_accumulation_steps)
    logging.info("Total optimization steps = %d" % t_total)

    trainer = Trainer(
        model=model,
        num_epochs=args.num_epochs,
        criterion=CrossEntropyLoss(),
        optimizer=optimizer,
        scheduler=scheduler,
        model_checkpoint=model_checkpoint,
        batch_metrics=[F1Score(task_type='multiclass', average='micro')],
        epoch_metrics=[F1Score(task_type='multiclass', average='micro')])
    trainer.train(train_data=train_dataloader, valid_data=valid_dataloader)
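The schedule length above is derived from the dataloader: total optimization steps are the batches per epoch divided by the gradient-accumulation factor, times the number of epochs, and warmup is a fixed fraction of that. A quick worked example with assumed numbers (1,000 batches, accumulation of 2, 3 epochs, 10% warmup):

# All numbers below are illustrative, not taken from the example above.
num_batches = 1000
gradient_accumulation_steps = 2
num_epochs = 3
warmup_proportion = 0.1

t_total = int(num_batches / gradient_accumulation_steps * num_epochs)  # 1500
warmup_steps = int(t_total * warmup_proportion)                        # 150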
Example #6
def test_train_with_lr_scheduler(mock_checkpoint):
    """
    Check that training loop runs without crashing, when there is no model
    and when there is a learning rate scheulder used
    """
    trainer = Trainer(use_cuda=USE_CUDA, wandb_name="my-model")
    trainer.setup_checkpoints("my-checkpoint", checkpoint_epochs=None)
    train_loader, test_loader = trainer.load_data_loaders(
        DummyDataset,
        batch_size=16,
        subsample=None,
        build_output=_build_output,
        length=128,
    )
    trainer.register_loss_fn(_get_mse_loss)
    trainer.register_metric_fn(_get_mse_metric, "Loss")
    trainer.input_shape = [1, 80, 256]
    trainer.target_shape = [1, 80, 256]
    trainer.output_shape = [1, 80, 256]
    net = trainer.load_net(
        DummyNet,
        input_shape=(16,) + INPUT_SHAPE,
        output_shape=(16,) + OUTPUT_SHAPE,
        use_cuda=USE_CUDA,
    )
    optimizer = trainer.load_optimizer(
        net, learning_rate=1e-4, adam_betas=[0.9, 0.99], weight_decay=1e-6
    )
    epochs = 5

    # One cycle learning rate
    steps_per_epoch = len(trainer.train_set) // 16
    trainer.use_one_cycle_lr_scheduler(optimizer, steps_per_epoch, epochs, 1e-3)

    mock_checkpoint.save.assert_not_called()
    trainer.train(net, epochs, optimizer, train_loader, test_loader)
    mock_checkpoint.save.assert_called_once_with(
        net, "my-checkpoint", name="my-model", use_wandb=False
    )
Example #7
def train(num_epochs, use_cuda, batch_size, wandb_name, subsample,
          checkpoint_epochs):
    trainer = Trainer(use_cuda, wandb_name)
    trainer.setup_checkpoints(CHECKPOINT_NAME, checkpoint_epochs)
    trainer.setup_wandb(
        WANDB_PROJECT,
        wandb_name,
        config={
            "Batch Size": batch_size,
            "Epochs": num_epochs,
            "Adam Betas": ADAM_BETAS,
            "Learning Rate": [MIN_LR, MAX_LR],
            "Weight Decay": WEIGHT_DECAY,
            "Fine Tuning": False,
        },
    )
    train_loader, test_loader = trainer.load_data_loaders(
        Dataset, batch_size, subsample)
    trainer.register_loss_fn(get_ce_loss)
    trainer.register_metric_fn(get_ce_metric, "Loss")
    trainer.register_metric_fn(get_accuracy_metric, "Accuracy")
    trainer.input_shape = [1, 80, 256]
    trainer.output_shape = [15]
    net = trainer.load_net(SpectralSceneNet)
    optimizer = trainer.load_optimizer(net,
                                       learning_rate=MIN_LR,
                                       adam_betas=ADAM_BETAS,
                                       weight_decay=WEIGHT_DECAY)

    # One cycle learning rate
    steps_per_epoch = len(trainer.train_set) // batch_size
    trainer.use_one_cycle_lr_scheduler(optimizer, steps_per_epoch, num_epochs,
                                       MAX_LR)

    trainer.train(net, num_epochs, optimizer, train_loader, test_loader)
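use_one_cycle_lr_scheduler is this Trainer's own helper; a minimal sketch of what such a helper presumably wraps, assuming PyTorch's built-in torch.optim.lr_scheduler.OneCycleLR with scheduler.step() called after every optimizer step:

from torch.optim.lr_scheduler import OneCycleLR

def use_one_cycle_lr_scheduler(optimizer, steps_per_epoch, epochs, max_lr):
    # Hypothetical stand-in for the trainer method used above.
    return OneCycleLR(
        optimizer,
        max_lr=max_lr,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
    )

# Inside the training loop, after each optimizer.step():
#     scheduler.step()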
Example #8
def train(num_epochs, use_cuda, batch_size, wandb_name, subsample,
          checkpoint_epochs):

    # Load loss net
    loss_net = load_checkpoint(LOSS_NET_CHECKPOINT, use_cuda=use_cuda)
    loss_net.set_feature_mode()
    loss_net.eval()

    feature_loss = AudioFeatureLoss(loss_net, use_cuda=use_cuda)

    def get_feature_loss(inputs, outputs, targets):
        return feature_loss(inputs, outputs, targets)

    def get_feature_loss_metric(inputs, outputs, targets):
        loss_t = feature_loss(inputs, outputs, targets)
        return loss_t.data.item()

    trainer = Trainer(num_epochs, wandb_name)
    trainer.setup_checkpoints(CHECKPOINT_NAME, checkpoint_epochs)
    trainer.setup_wandb(
        WANDB_PROJECT,
        wandb_name,
        config={
            "Batch Size": batch_size,
            "Epochs": num_epochs,
            "Adam Betas": ADAM_BETAS,
            "Learning Rate": LEARNING_RATE,
            "Weight Decay": WEIGHT_DECAY,
            "Fine Tuning": False,
        },
    )
    train_loader, test_loader = trainer.load_data_loaders(
        Dataset, batch_size, subsample)
    trainer.register_loss_fn(get_feature_loss)
    trainer.register_metric_fn(get_mse_metric, "Loss")
    trainer.register_metric_fn(get_feature_loss_metric, "Feature Loss")
    trainer.input_shape = [2**15]
    trainer.target_shape = [2**15]
    trainer.output_shape = [2**15]
    net = trainer.load_net(WaveUNet)
    optimizer = trainer.load_optimizer(
        net,
        learning_rate=LEARNING_RATE,
        adam_betas=ADAM_BETAS,
        weight_decay=WEIGHT_DECAY,
    )
    trainer.train(net, num_epochs, optimizer, train_loader, test_loader)
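AudioFeatureLoss compares activations of the frozen, pre-trained loss net rather than raw waveforms (a perceptual / deep-feature loss). Its internals are not shown here; one common formulation, sketched under the assumption that the loss net returns a list of intermediate feature maps:

import torch

class FeatureLoss:
    """Sum of L1 distances between loss-net activations (illustrative only)."""

    def __init__(self, loss_net):
        self.loss_net = loss_net  # frozen, eval-mode feature extractor

    def __call__(self, inputs, outputs, targets):
        # Assumed: loss_net(x) returns a list of intermediate feature tensors.
        output_feats = self.loss_net(outputs)
        target_feats = self.loss_net(targets)
        return sum(
            torch.nn.functional.l1_loss(o, t)
            for o, t in zip(output_feats, target_feats)
        )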
Example #9
def reconstruction_loss(x, x_recons):
    return Trainer.mean_crossentropy_loss(weights=x_recons, targets=x)
Example #10
def train(num_epochs, use_cuda, batch_size, wandb_name, subsample, checkpoint_epochs):
    batch_size = BATCH_SIZE
    trainer = Trainer(use_cuda, wandb_name)
    trainer.setup_checkpoints(CHECKPOINT_NAME, checkpoint_epochs)
    trainer.setup_wandb(
        WANDB_PROJECT,
        wandb_name,
        config={
            "Batch Size": batch_size,
            "Epochs": num_epochs,
            "Adam Betas": ADAM_BETAS,
            "Learning Rate": LEARNING_RATE,
            "Weight Decay": WEIGHT_DECAY,
            "Fine Tuning": False,
        },
    )
    train_loader, test_loader = trainer.load_data_loaders(Dataset, batch_size, subsample)
    trainer.register_loss_fn(get_ce_loss)
    trainer.register_metric_fn(get_ce_metric, "Loss")
    trainer.input_shape = [32767]
    net = trainer.load_net(SceneNet)
    optimizer = trainer.load_optimizer(
        net, learning_rate=LEARNING_RATE, adam_betas=ADAM_BETAS, weight_decay=WEIGHT_DECAY
    )
    trainer.train(net, num_epochs, optimizer, train_loader, test_loader)

    # Do a fine-tuning run with 1/10th the learning rate and weight decay, for 1/3 of the epochs.
    optimizer = trainer.load_optimizer(
        net, learning_rate=LEARNING_RATE / 10, adam_betas=ADAM_BETAS, weight_decay=WEIGHT_DECAY / 10
    )
    num_epochs = num_epochs // 3
    trainer.train(net, num_epochs, optimizer, train_loader, test_loader)
Example #11
def main():
    """
    main
    """
    config = model_config()
    if config.check:
        config.save_dir = "./tmp/"
    config.use_gpu = torch.cuda.is_available() and config.gpu >= 0
    if config.use_gpu:
        torch.cuda.set_device(config.gpu)

    # Data definition
    train_file = "%s/json.train.txt" % config.data_dir
    dev_file = "%s/json.dev.txt" % config.data_dir
    test_file = "%s/json.test.txt" % config.data_dir
    vocab_file = "%s/train.vocab" % config.data_dir

    word2index, index2word = load_vocab(vocab_file, config.vocab_size)
    vocab_size = len(word2index)

    train_iter = prepare_batcher(train_file,
                                 word2index,
                                 batch_size=config.batch_size,
                                 is_shuffle=True)
    valid_iter = prepare_batcher(dev_file,
                                 word2index,
                                 batch_size=config.batch_size,
                                 is_shuffle=True)
    test_iter = prepare_batcher(test_file,
                                word2index,
                                batch_size=config.batch_size,
                                is_shuffle=False)

    # Model definition
    model = GLMP(index2word,
                 vocab_size=vocab_size,
                 hidden_size=config.hidden_size,
                 embed_dim=config.embed_size,
                 max_resp_len=config.max_dec_len,
                 n_layers=config.num_layers,
                 hop=config.hop,
                 dropout=config.dropout,
                 teacher_forcing_ratio=config.teacher_forcing_ratio,
                 use_cuda=config.use_gpu,
                 use_record=config.use_record,
                 unk_mask=config.unk_mask)

    # Testing
    if config.test and config.ckpt:
        print(model)
        model.load(save_dir=config.save_dir, file_prefix=config.ckpt)

        print("Generating ...")
        if not os.path.exists(config.output_dir):
            os.makedirs(config.output_dir)
        model.generate(test_iter, output_dir=config.output_dir, verbose=True)
    else:
        # Save directory
        if not os.path.exists(config.save_dir):
            os.makedirs(config.save_dir)

        # Optimizer definition
        optimizer = getattr(torch.optim, config.optimizer)(model.parameters(),
                                                           lr=config.lr)
        # Learning rate scheduler
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer,
            factor=config.lr_decay,
            patience=1,
            verbose=True,
            min_lr=1e-5)

        # Logger definition
        logger = logging.getLogger()
        logging.basicConfig(level=logging.DEBUG, format="%(message)s")
        fh = logging.FileHandler(os.path.join(config.save_dir, "train.log"))
        logger.addHandler(fh)
        # Save config
        params_file = os.path.join(config.save_dir, "params.json")
        with open(params_file, 'w') as fp:
            json.dump(config.__dict__, fp, indent=4, sort_keys=True)
        print("Saved params to '{}'".format(params_file))
        logger.info(model)

        # Train
        logger.info("Training starts ...")
        trainer = Trainer(model=model,
                          optimizer=optimizer,
                          train_iter=train_iter,
                          valid_iter=valid_iter,
                          logger=logger,
                          valid_metric_name="-loss",
                          num_epochs=config.num_epochs,
                          save_dir=config.save_dir,
                          log_steps=config.log_steps,
                          valid_steps=config.valid_steps,
                          grad_clip=config.grad_clip,
                          lr_scheduler=lr_scheduler)
        if config.ckpt is not None:
            trainer.load(save_dir=config.save_dir, file_prefix=config.ckpt)
        trainer.train()
        logger.info("Training done!")
Example #12
def train(runtime, training, logging):

    # Load feature loss net
    loss_net = load_checkpoint(LOSS_NET_CHECKPOINT, use_cuda=runtime["cuda"])
    loss_net.set_feature_mode(num_layers=6)
    loss_net.eval()
    feature_loss = AudioFeatureLoss(loss_net, use_cuda=runtime["cuda"])

    def get_feature_loss(inputs, outputs, targets):
        return feature_loss(inputs, outputs, targets)

    def get_feature_loss_metric(inputs, outputs, targets):
        loss_t = feature_loss(inputs, outputs, targets)
        return loss_t.data.item()

    batch_size = training["batch_size"]
    epochs = training["epochs"]
    subsample = training["subsample"]
    trainer = Trainer(**runtime)
    trainer.setup_checkpoints(**logging["checkpoint"])
    trainer.setup_wandb(**logging["wandb"],
                        run_info={
                            "Batch Size": batch_size,
                            "Epochs": epochs,
                            "Adam Betas": ADAM_BETAS,
                            "Learning Rate": [MIN_LR, MAX_LR],
                            "Weight Decay": WEIGHT_DECAY,
                            "Fine Tuning": False,
                        })

    train_loader, test_loader = trainer.load_data_loaders(
        Dataset, batch_size, subsample)

    trainer.register_loss_fn(get_feature_loss)
    trainer.register_metric_fn(get_mse_metric, "Loss")
    trainer.register_metric_fn(get_feature_loss_metric, "Feature Loss")

    trainer.input_shape = [1, 80, 256]
    trainer.target_shape = [1, 80, 256]
    trainer.output_shape = [1, 80, 256]
    net = trainer.load_net(SpectralUNet)

    optimizer = trainer.load_optimizer(net,
                                       learning_rate=MIN_LR,
                                       adam_betas=ADAM_BETAS,
                                       weight_decay=WEIGHT_DECAY)
    steps_per_epoch = len(trainer.train_set) // batch_size
    trainer.use_one_cycle_lr_scheduler(optimizer, steps_per_epoch, epochs,
                                       MAX_LR)

    trainer.train(net, epochs, optimizer, train_loader, test_loader)
Example #13
def train(num_epochs, use_cuda, batch_size, wandb_name, subsample, checkpoint_epochs):
    trainer = Trainer(num_epochs, wandb_name)
    trainer.setup_checkpoints(CHECKPOINT_NAME, checkpoint_epochs)
    trainer.setup_wandb(
        WANDB_PROJECT,
        wandb_name,
        config={
            "Batch Size": batch_size,
            "Epochs": num_epochs,
            "Adam Betas": ADAM_BETAS,
            "Learning Rate": LEARNING_RATE,
            "Disc Learning Rate": DISC_LEARNING_RATE,
            "Disc Weight": DISC_WEIGHT,
            "Weight Decay": WEIGHT_DECAY,
            "Fine Tuning": False,
        },
    )

    # Construct generator network
    gen_net = trainer.load_net(WaveUNet)
    gen_optimizer = trainer.load_optimizer(
        gen_net,
        learning_rate=LEARNING_RATE,
        adam_betas=ADAM_BETAS,
        weight_decay=WEIGHT_DECAY,
    )
    train_loader, test_loader = trainer.load_data_loaders(
        NoisySpeechDataset, batch_size, subsample
    )

    # Construct discriminator network
    disc_net = trainer.load_net(MelDiscriminatorNet)
    disc_loss = LeastSquaresLoss(disc_net)
    disc_optimizer = trainer.load_optimizer(
        disc_net,
        learning_rate=DISC_LEARNING_RATE,
        adam_betas=ADAM_BETAS,
        weight_decay=WEIGHT_DECAY,
    )

    # First, train generator using MSE loss
    disc_net.freeze()
    gen_net.unfreeze()
    trainer.register_loss_fn(get_mse_loss)
    trainer.register_metric_fn(get_mse_metric, "Loss")
    trainer.input_shape = [2 ** 15]
    trainer.target_shape = [2 ** 15]
    trainer.output_shape = [2 ** 15]
    trainer.train(gen_net, num_epochs, gen_optimizer, train_loader, test_loader)

    # Next, train GAN using the output of the generator
    def get_disc_loss(_, fake_audio, real_audio):
        """
        We want to compare the inputs (real audio) with the generated output (fake audio).
        """
        return disc_loss.for_discriminator(real_audio, fake_audio)

    def get_disc_metric(_, fake_audio, real_audio):
        loss_t = disc_loss.for_discriminator(real_audio, fake_audio)
        return loss_t.data.item()

    disc_net.unfreeze()
    gen_net.freeze()
    trainer.loss_fns = []
    trainer.metric_fns = []
    trainer.register_loss_fn(get_disc_loss)
    trainer.register_metric_fn(get_disc_metric, "Discriminator Loss")
    trainer.train(gen_net, num_epochs, disc_optimizer, train_loader, test_loader)

    # Finally, train the generator using the discriminator and MSE loss
    def get_gen_loss(_, fake_audio, real_audio):
        return disc_loss.for_generator(real_audio, fake_audio)

    def get_gen_metric(_, fake_audio, real_audio):
        loss_t = disc_loss.for_generator(real_audio, fake_audio)
        return loss_t.data.item()

    disc_net.freeze()
    gen_net.unfreeze()
    trainer.loss_fns = []
    trainer.metric_fns = []
    trainer.register_loss_fn(get_mse_loss)
    trainer.register_loss_fn(get_gen_loss, weight=DISC_WEIGHT)
    trainer.register_metric_fn(get_mse_metric, "Loss")
    trainer.register_metric_fn(get_gen_metric, "Generator Loss")
    trainer.train(gen_net, num_epochs, gen_optimizer, train_loader, test_loader)
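LeastSquaresLoss is this repo's own wrapper; the underlying least-squares GAN objectives (Mao et al.) are standard. A minimal sketch, assuming the discriminator returns unbounded scores and using the usual 0/1 targets; not necessarily this repo's implementation:

import torch

class LeastSquaresGANLoss:
    """Illustrative LSGAN objectives, following the for_discriminator/for_generator usage above."""

    def __init__(self, disc_net):
        self.disc_net = disc_net

    def for_discriminator(self, real_audio, fake_audio):
        # Push real scores toward 1 and fake scores toward 0.
        real_scores = self.disc_net(real_audio)
        fake_scores = self.disc_net(fake_audio.detach())
        return 0.5 * (((real_scores - 1) ** 2).mean() + (fake_scores ** 2).mean())

    def for_generator(self, real_audio, fake_audio):
        # Push the discriminator's score for fake audio toward 1.
        fake_scores = self.disc_net(fake_audio)
        return 0.5 * ((fake_scores - 1) ** 2).mean()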