Example No. 1
def main():
    # setup logger
    if args.resume_dir == "":
        date = str(datetime.datetime.now())
        date = date[:date.rfind(":")].replace("-", "") \
            .replace(":", "") \
            .replace(" ", "_")
        log_dir = os.path.join(args.log_root, "log_" + date)
    else:
        log_dir = args.resume_dir
    hparams_file = os.path.join(log_dir, "hparams.json")
    checkpoints_dir = os.path.join(log_dir, "checkpoints")
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(checkpoints_dir):
        os.makedirs(checkpoints_dir)
    if args.resume_dir == "":
        # write hparams
        with open(hparams_file, "w") as f:
            json.dump(args.__dict__, f, indent=2)
    log_file = os.path.join(log_dir, "log_train.txt")
    logger = Logger(log_file)
    # logger.info(args)
    logger.info("The args corresponding to training process are: ")
    for (key, value) in vars(args).items():
        logger.info("{key:20}: {value:}".format(key=key, value=value))

    # --------------------------------------------------------------------------------------------
    #   INSTANTIATE VOCABULARY, DATALOADER, MODEL, OPTIMIZER
    # --------------------------------------------------------------------------------------------

    train_dataset = COCO_Search18(args.img_dir,
                                  args.fix_dir,
                                  args.detector_dir,
                                  blur_sigma=args.blur_sigma,
                                  type="train",
                                  split="split3",
                                  transform=transform,
                                  detector_threshold=args.detector_threshold)
    train_dataset_rl = COCO_Search18_rl(
        args.img_dir,
        args.fix_dir,
        args.detector_dir,
        type="train",
        split="split3",
        transform=transform,
        detector_threshold=args.detector_threshold)
    validation_dataset = COCO_Search18_evaluation(
        args.img_dir,
        args.fix_dir,
        args.detector_dir,
        type="validation",
        split="split3",
        transform=transform,
        detector_threshold=args.detector_threshold)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch,
                              shuffle=True,
                              num_workers=1,
                              collate_fn=train_dataset.collate_func)
    train_rl_loader = DataLoader(dataset=train_dataset_rl,
                                 batch_size=args.batch // 4,
                                 shuffle=True,
                                 num_workers=4,
                                 collate_fn=train_dataset_rl.collate_func)
    validation_loader = DataLoader(dataset=validation_dataset,
                                   batch_size=args.batch,
                                   shuffle=False,
                                   num_workers=4,
                                   collate_fn=validation_dataset.collate_func)

    model = baseline(embed_size=512,
                     convLSTM_length=args.max_length,
                     min_length=args.min_length).cuda()

    sampling = Sampling(convLSTM_length=args.max_length,
                        min_length=args.min_length)

    # optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=args.weight_decay, nesterov=True)
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=args.weight_decay)

    # --------------------------------------------------------------------------------------------
    #  BEFORE TRAINING STARTS
    # --------------------------------------------------------------------------------------------

    # Tensorboard summary writer for logging losses and metrics.
    tensorboard_writer = SummaryWriter(log_dir=log_dir)

    # Record manager for writing and loading the best metrics and their corresponding epoch
    record_manager = RecordManager(log_dir)
    if args.resume_dir == '':
        record_manager.init_record()
    else:
        record_manager.load()

    start_epoch = record_manager.get_epoch()
    iteration = record_manager.get_iteration()
    best_metric = record_manager.get_best_metric()

    # Checkpoint manager to serialize checkpoints periodically while training and keep track of
    # best performing checkpoint.
    checkpoint_manager = CheckpointManager(model,
                                           optimizer,
                                           checkpoints_dir,
                                           mode="max",
                                           best_metric=best_metric)

    # Load checkpoint to resume training from there if specified.
    # Infer iteration number through file name (it's hacky but very simple), so don't rename
    # saved checkpoints if you intend to continue training.
    if args.resume_dir != "":
        training_checkpoint = torch.load(
            os.path.join(checkpoints_dir, "checkpoint.pth"))
        for key in training_checkpoint:
            if key == "optimizer":
                optimizer.load_state_dict(training_checkpoint[key])
            else:
                model.load_state_dict(training_checkpoint[key])

    # lr_scheduler = optim.lr_scheduler.LambdaLR \
    #     (optimizer, lr_lambda=lambda iteration: 1 - iteration / (len(train_loader) * args.epoch), last_epoch=iteration)

    def lr_lambda(iteration):
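        # LR multiplier schedule: linear warmup over the first `warmup_epoch`
        # epochs, linear decay until the supervised stage ends at
        # `start_rl_epoch`, then a further linear decay (scaled by
        # `rl_lr_initial_decay`) over the RL stage.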
        if iteration <= len(train_loader) * args.warmup_epoch:
            return iteration / (len(train_loader) * args.warmup_epoch)
        elif iteration <= len(train_loader) * args.start_rl_epoch:
            return 1 - (iteration - len(train_loader) * args.warmup_epoch) /\
                   (len(train_loader) * (args.start_rl_epoch - args.warmup_epoch))
        else:
            return args.rl_lr_initial_decay * (
                1 - (iteration - (len(train_loader) * args.start_rl_epoch)) /
                (len(train_rl_loader) * (args.epoch - args.start_rl_epoch)))

    lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                               lr_lambda=lr_lambda,
                                               last_epoch=iteration)

    if len(args.gpu_ids) > 1:
        model = nn.DataParallel(model, args.gpu_ids)

    def train(iteration, epoch):
        # traditional training stage
        if epoch < args.start_rl_epoch:
            model.train()
            for i_batch, batch in enumerate(train_loader):
                tmp = [
                    batch["images"], batch["scanpaths"], batch["durations"],
                    batch["action_masks"], batch["duration_masks"],
                    batch["attention_maps"], batch["tasks"]
                ]
                tmp = [_ if not torch.is_tensor(_) else _.cuda() for _ in tmp]
                images, scanpaths, durations, action_masks, duration_masks, attention_maps, tasks = tmp

                optimizer.zero_grad()

                if args.ablate_attention_info:
                    attention_maps *= 0

                predicts = model(images, attention_maps, tasks)

                loss_actions = CrossEntropyLoss(predicts["all_actions_prob"],
                                                scanpaths, action_masks)
                loss_duration = MLPLogNormalDistribution(
                    predicts["log_normal_mu"], predicts["log_normal_sigma2"],
                    durations, duration_masks)

                loss = loss_actions + args.lambda_1 * loss_duration

                loss.backward()
                if args.clip > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.clip)
                optimizer.step()

                iteration += 1
                lr_scheduler.step()
                pbar.update(1)
                # Log loss and learning rate to tensorboard.
                tensorboard_writer.add_scalar("loss/loss", loss, iteration)
                tensorboard_writer.add_scalar("loss/loss_actions",
                                              loss_actions, iteration)
                tensorboard_writer.add_scalar("loss/loss_duration",
                                              loss_duration, iteration)
                tensorboard_writer.add_scalar("learning_rate",
                                              optimizer.param_groups[0]["lr"],
                                              iteration)
        # reinforcement learning stage
        else:
            model.eval()
            # create a ScanMatch object
            ScanMatchwithDuration = ScanMatch(Xres=320,
                                              Yres=240,
                                              Xbin=16,
                                              Ybin=12,
                                              Offset=(0, 0),
                                              TempBin=50,
                                              Threshold=3.5)
            ScanMatchwithoutDuration = ScanMatch(Xres=320,
                                                 Yres=240,
                                                 Xbin=16,
                                                 Ybin=12,
                                                 Offset=(0, 0),
                                                 Threshold=3.5)
            for i_batch, batch in enumerate(train_rl_loader):

                tmp = [
                    batch["images"], batch["fix_vectors"],
                    batch["attention_maps"], batch["tasks"]
                ]
                tmp = [_ if not torch.is_tensor(_) else _.cuda() for _ in tmp]
                images, gt_fix_vectors, attention_maps, tasks = tmp
                N, C, H, W = images.shape
                optimizer.zero_grad()

                if args.ablate_attention_info:
                    attention_maps *= 0

                metrics_reward_batch = []
                neg_log_actions_batch = []
                neg_log_durations_batch = []

                # get the random sample prediction
                predict = model(images, attention_maps, tasks)
                log_normal_mu = predict["log_normal_mu"]
                log_normal_sigma2 = predict["log_normal_sigma2"]
                all_actions_prob = predict["all_actions_prob"]

                trial = 0
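                # Keep sampling until `rl_sample_number` scanpaths with valid
                # (non-NaN) ScanMatch rewards have been collected.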
                while True:
                    if trial >= args.rl_sample_number:
                        break

                    samples = sampling.random_sample(all_actions_prob,
                                                     log_normal_mu,
                                                     log_normal_sigma2)

                    prob_sample_actions = samples["selected_actions_probs"]
                    durations = samples["durations"]
                    sample_actions = samples["selected_actions"]
                    random_predict_fix_vectors, action_masks, duration_masks = sampling.generate_scanpath(
                        images, prob_sample_actions, durations, sample_actions)
                    t = durations.data.clone()

                    metrics_reward = pairs_eval_scanmatch(
                        gt_fix_vectors, random_predict_fix_vectors,
                        ScanMatchwithDuration, ScanMatchwithoutDuration)

                    if np.any(np.isnan(metrics_reward)):
                        continue
                    else:
                        trial += 1
                        metrics_reward = torch.tensor(metrics_reward,
                                                      dtype=torch.float32).to(
                                                          images.get_device())
                        neg_log_actions = -LogAction(prob_sample_actions,
                                                     action_masks)
                        neg_log_durations = -LogDuration(
                            t, log_normal_mu, log_normal_sigma2,
                            duration_masks)
                        metrics_reward_batch.append(
                            metrics_reward.unsqueeze(0))
                        neg_log_actions_batch.append(
                            neg_log_actions.unsqueeze(0))
                        neg_log_durations_batch.append(
                            neg_log_durations.unsqueeze(0))

                neg_log_actions_tensor = torch.cat(neg_log_actions_batch,
                                                   dim=0)
                neg_log_durations_tensor = torch.cat(neg_log_durations_batch,
                                                     dim=0)
                # reward: harmonic mean of the ScanMatch metrics for each sampled scanpath
                metrics_reward_tensor = torch.cat(metrics_reward_batch, dim=0)
                metrics_reward_hmean = scipy.stats.hmean(
                    metrics_reward_tensor[:, :, :].cpu(), axis=-1)
                metrics_reward_hmean_tensor = torch.tensor(
                    metrics_reward_hmean).to(
                        metrics_reward_tensor.get_device())
                baseline_reward_hmean_tensor = metrics_reward_hmean_tensor.mean(
                    0, keepdim=True)
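                # REINFORCE with a baseline: each sample's negative log-probability
                # is weighted by its advantage, i.e. its harmonic-mean reward minus
                # the mean reward across the sampled scanpaths.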

                loss_actions = (neg_log_actions_tensor *
                                (metrics_reward_hmean_tensor -
                                 baseline_reward_hmean_tensor)).sum()
                loss_duration = (neg_log_durations_tensor *
                                 (metrics_reward_hmean_tensor -
                                  baseline_reward_hmean_tensor)).sum()
                loss = loss_actions + loss_duration

                loss.backward()
                if args.clip > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.clip)
                optimizer.step()

                iteration += 1
                lr_scheduler.step()
                pbar.update(1)
                # Log loss and learning rate to tensorboard.
                multimatch_metric_names = ["w/o duration", "w/ duration"]
                multimatch_metrics_reward = metrics_reward_tensor.mean(0).mean(
                    0)
                tensorboard_writer.add_scalar("rl_loss", loss, iteration)
                tensorboard_writer.add_scalar("reward_hmean",
                                              metrics_reward_hmean.mean(),
                                              iteration)
                tensorboard_writer.add_scalar("learning_rate",
                                              optimizer.param_groups[0]["lr"],
                                              iteration)
                for metric_index in range(len(multimatch_metric_names)):
                    tensorboard_writer.add_scalar(
                        "metrics_for_reward/{metric_name}".format(
                            metric_name=multimatch_metric_names[metric_index]),
                        multimatch_metrics_reward[metric_index], iteration)

        return iteration

    def validation(iteration):
        model.eval()
        repeat_num = args.eval_repeat_num
        all_gt_fix_vectors = []
        all_predict_fix_vectors = []
        with tqdm(total=len(validation_loader) * repeat_num) as pbar_val:
            for i_batch, batch in enumerate(validation_loader):
                tmp = [
                    batch["images"], batch["fix_vectors"],
                    batch["attention_maps"], batch["tasks"]
                ]
                tmp = [_ if not torch.is_tensor(_) else _.cuda() for _ in tmp]
                images, gt_fix_vectors, attention_maps, tasks = tmp

                if args.ablate_attention_info:
                    attention_maps *= 0

                with torch.no_grad():
                    predict = model(images, attention_maps, tasks)

                log_normal_mu = predict["log_normal_mu"]
                log_normal_sigma2 = predict["log_normal_sigma2"]
                all_actions_prob = predict["all_actions_prob"]

                for trial in range(repeat_num):
                    all_gt_fix_vectors.extend(gt_fix_vectors)

                    samples = sampling.random_sample(all_actions_prob,
                                                     log_normal_mu,
                                                     log_normal_sigma2)
                    prob_sample_actions = samples["selected_actions_probs"]
                    durations = samples["durations"]
                    sample_actions = samples["selected_actions"]
                    sampling_random_predict_fix_vectors, _, _ = sampling.generate_scanpath(
                        images, prob_sample_actions, durations, sample_actions)
                    all_predict_fix_vectors.extend(
                        sampling_random_predict_fix_vectors)

                    pbar_val.update(1)

        cur_metrics, cur_metrics_std, _ = evaluation(all_gt_fix_vectors,
                                                     all_predict_fix_vectors)

        # Print and log all evaluation metrics to tensorboard.
        logger.info("Evaluation metrics after iteration {iteration}:".format(
            iteration=iteration))
        for metrics_key in cur_metrics.keys():
            for (metric_name,
                 metric_value) in cur_metrics[metrics_key].items():
                tensorboard_writer.add_scalar(
                    "metrics/{metrics_key}-{metric_name}".format(
                        metrics_key=metrics_key, metric_name=metric_name),
                    metric_value, iteration)
                logger.info(
                    "{metrics_key:10}-{metric_name:15}: {metric_value:.4f} +- {std:.4f}"
                    .format(metrics_key=metrics_key,
                            metric_name=metric_name,
                            metric_value=metric_value,
                            std=cur_metrics_std[metrics_key][metric_name]))

        return cur_metrics

    # get the human baseline score
    human_metrics, human_metrics_std, _ = human_evaluation(validation_loader)
    logger.info("The metrics for human performance are: ")
    for metrics_key in human_metrics.keys():
        for (key, value) in human_metrics[metrics_key].items():
            logger.info(
                "{metrics_key:10}-{key:15}: {value:.4f} +- {std:.4f}".format(
                    metrics_key=metrics_key,
                    key=key,
                    value=value,
                    std=human_metrics_std[metrics_key][key]))

    tqdm_total = len(train_loader) * args.start_rl_epoch + len(
        train_rl_loader) * (args.epoch - args.start_rl_epoch)
    with tqdm(total=tqdm_total, initial=iteration + 1) as pbar:
        for epoch in range(start_epoch + 1, args.epoch):
            iteration = train(iteration, epoch)
            cur_metrics = validation(iteration)
            cur_metric = scipy.stats.hmean(
                list(cur_metrics["ScanMatch"].values()))

            # Log current metric to tensorboard.
            tensorboard_writer.add_scalar("current metric", float(cur_metric),
                                          iteration)
            logger.info("{key:10}: {value:.4f}".format(
                key="current metric", value=float(cur_metric)))

            # save
            checkpoint_manager.step(float(cur_metric))
            best_metric = checkpoint_manager.get_best_metric()
            record_manager.save(epoch, iteration, best_metric)

            # check whether to save a copy of the final supervised-training log directory
            if args.supervised_save and epoch == args.start_rl_epoch - 1:
                cmd = 'cp -r ' + log_dir + ' ' + log_dir + '_supervised_save'
                os.system(cmd)
Example No. 2
class Summarization(object):
    def __init__(self, hparams, mode='train'):
        self.hparams = hparams
        self._logger = logging.getLogger(__name__)
        print('self.hparams:', self.hparams)
        self.logger = logging.getLogger(__name__)

        if hparams.device == 'cuda':
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        self.build_dataloader()

        self.save_dirpath = self.hparams.save_dirpath
        today = str(datetime.today().month) + 'M_' + str(
            datetime.today().day) + 'D' + '_GEN_MAX_' + str(
                self.hparams.gen_max_length)
        tensorboard_path = self.save_dirpath + today
        self.summary_writer = SummaryWriter(tensorboard_path, comment="Unmt")

        if mode == 'train':
            self.build_model()
            self.setup_training()
            self.predictor = self.build_eval_model(
                model=self.model, summary_writer=self.summary_writer)
            dump_vocab(self.hparams.save_dirpath + 'vocab_word',
                       self.vocab_word)

        elif mode == 'eval':
            self.predictor = self.build_eval_model(
                summary_writer=self.summary_writer)

    def build_dataloader(self):
        self.train_dataset = AMIDataset(self.hparams, type='train')
        self.train_dataloader = DataLoader(self.train_dataset,
                                           batch_size=self.hparams.batch_size,
                                           num_workers=self.hparams.workers,
                                           shuffle=True,
                                           drop_last=True)
        self.vocab_word = self.train_dataset.vocab_word
        self.vocab_role = self.train_dataset.vocab_role
        self.vocab_pos = self.train_dataset.vocab_pos

        self.test_dataset = AMIDataset(self.hparams,
                                       type='test',
                                       vocab_word=self.vocab_word,
                                       vocab_role=self.vocab_role,
                                       vocab_pos=self.vocab_pos)
        self.test_dataloader = DataLoader(self.test_dataset,
                                          batch_size=self.hparams.batch_size,
                                          num_workers=self.hparams.workers,
                                          drop_last=False)

    print("""
           # -------------------------------------------------------------------------
           #   DATALOADER FINISHED
           # -------------------------------------------------------------------------
           """)

    def build_model(self):
        # Define model
        self.model = SummarizationModel(hparams=self.hparams,
                                        vocab_word=self.vocab_word,
                                        vocab_role=self.vocab_role,
                                        vocab_pos=self.vocab_pos)

        # Multi-GPU
        self.model = self.model.to(self.device)

        # Use Multi-GPUs
        if -1 not in self.hparams.gpu_ids and len(self.hparams.gpu_ids) > 1:
            self.model = nn.DataParallel(self.model, self.hparams.gpu_ids)

        # Define Loss and Optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.hparams.learning_rate,
                                    betas=(self.hparams.optimizer_adam_beta1,
                                           self.hparams.optimizer_adam_beta2))

    def setup_training(self):
        self.save_dirpath = self.hparams.save_dirpath
        today = str(datetime.today().month) + 'M_' + str(
            datetime.today().day) + 'D'
        tensorboard_path = self.save_dirpath + today
        self.summary_writer = SummaryWriter(tensorboard_path, comment="Unmt")
        self.checkpoint_manager = CheckpointManager(self.model,
                                                    self.optimizer,
                                                    self.save_dirpath,
                                                    hparams=self.hparams)

        # If loading from checkpoint, adjust start epoch and load parameters.
        if self.hparams.load_pthpath == "":
            self.start_epoch = 1
        else:
            # "path/to/checkpoint_xx.pth" -> xx
            self.start_epoch = int(
                self.hparams.load_pthpath.split("_")[-1][:-4])
            self.start_epoch += 1
            model_state_dict, optimizer_state_dict = load_checkpoint(
                self.hparams.load_pthpath)
            if isinstance(self.model, nn.DataParallel):
                self.model.module.load_state_dict(model_state_dict,
                                                  strict=True)
            else:
                self.model.load_state_dict(model_state_dict)

            self.optimizer.load_state_dict(optimizer_state_dict)
            self.previous_model_path = self.hparams.load_pthpath
            print("Loaded model from {}".format(self.hparams.load_pthpath))

        print("""
            # -------------------------------------------------------------------------
            #   Setup Training Finished
            # -------------------------------------------------------------------------
            """)

    def build_eval_model(self,
                         model=None,
                         summary_writer=None,
                         eval_path=None):
        # Define predictor
        predictor = Predictor(self.hparams,
                              model=model,
                              vocab_word=self.vocab_word,
                              vocab_role=self.vocab_role,
                              vocab_pos=self.vocab_pos,
                              checkpoint=eval_path,
                              summary_writer=summary_writer)

        return predictor

    def train(self):
        train_begin = datetime.utcnow()  # News
        global_iteration_step = 0
        for epoch in range(self.hparams.num_epochs):
            self.model.train()
            tqdm_batch_iterator = tqdm(self.train_dataloader)
            for batch_idx, batch in enumerate(tqdm_batch_iterator):
                data = batch
                dialogues_ids = data['dialogues_ids'].to(self.device)
                pos_ids = data['pos_ids'].to(self.device)
                labels_ids = data['labels_ids'].to(
                    self.device)  # [batch==1, tgt_seq_len]
                src_masks = data['src_masks'].to(self.device)
                role_ids = data['role_ids'].to(self.device)

                logits = self.model(
                    inputs=dialogues_ids,
                    targets=labels_ids[:, :-1],  # before <END> token
                    src_masks=src_masks,
                    role_ids=role_ids,
                    pos_ids=pos_ids)  # [batch x tgt_seq_len, vocab_size]
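                # Teacher forcing: the decoder sees labels_ids[:, :-1], so the
                # targets below are shifted by one position and flattened for the
                # token-level cross-entropy.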

                labels_ids = labels_ids[:, 1:]
                labels_ids = labels_ids.view(
                    labels_ids.shape[0] *
                    labels_ids.shape[1])  # [batch x tgt_seq_len]

                loss = self.criterion(logits, labels_ids)
                loss.backward()

                # gradient clipping
                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.hparams.max_gradient_norm)
                self.optimizer.step()
                self.optimizer.zero_grad()

                global_iteration_step += 1
                description = "[{}][Epoch: {:3d}][Iter: {:6d}][Loss: {:6f}][lr: {:7f}]".format(
                    datetime.utcnow() - train_begin, epoch,
                    global_iteration_step, loss,
                    self.optimizer.param_groups[0]['lr'])
                tqdm_batch_iterator.set_description(description)

            # # -------------------------------------------------------------------------
            # #   ON EPOCH END  (checkpointing and validation)
            # # -------------------------------------------------------------------------
            self.checkpoint_manager.step(epoch)
            self.previous_model_path = os.path.join(
                self.checkpoint_manager.ckpt_dirpath,
                "checkpoint_%d.pth" % (epoch))
            self._logger.info(self.previous_model_path)

            # torch.cuda.empty_cache()

            if epoch % 10 == 0 and epoch >= self.hparams.start_eval_epoch:
                print('======= Evaluation Start Epoch: ', epoch,
                      ' ==================')

                self.predictor.evaluate(test_dataloader=self.test_dataloader,
                                        epoch=epoch,
                                        eval_path=self.previous_model_path)

                print(
                    '============================================================\n\n'
                )
Example No. 3
def train():

    #################################################
    # Argparse stuff click was a bad idea after all #
    #################################################

    parser = argparse.ArgumentParser()

    parser.add_argument('--config-json',
                        type=str,
                        help="The json file specifying the args below")

    parser.add_argument('--embedder-path',
                        type=str,
                        help="Path to the embedder checkpoint." +
                        " Example: 'embedder/data/best_model'")

    group_chk = parser.add_argument_group('checkpointing')
    group_chk.add_argument('--epoch-save-interval',
                           type=int,
                           help="After every [x] epochs save w/" +
                           "checkpoint manager")
    group_chk.add_argument('--save-dir',
                           type=str,
                           help="Relative path of save directory, " +
                           "include the trailing /")
    group_chk.add_argument("--load-dir",
                           type=str,
                           help="Checkpoint prefix directory to " +
                           "load initial model from")

    group_system = parser.add_argument_group('system')
    group_system.add_argument('--cpu-workers',
                              type=int,
                              help="Number of CPU workers for dataloader")
    group_system.add_argument('--torch-seed',
                              type=int,
                              help="Seed for for torch and torch_cudnn")
    group_system.add_argument('--gpu-ids',
                              help="The GPU ID to use. If -1, use CPU")

    group_data = parser.add_argument_group('data')
    group_data.add_argument('--mel-size',
                            type=int,
                            help="Number of channels in the mel-gram")
    group_data.add_argument('--style-size',
                            type=int,
                            help="Dimensionality of style vector")
    group_data.add_argument('--dset-num-people',
                            type=int,
                            help="If using VCTK, an integer under 150")
    group_data.add_argument('--dset-num-samples',
                            type=int,
                            help="If using VCTK, an integer under 300")
    group_data.add_argument('--mel-root',
                            default='data/taco/',
                            type=str,
                            help='Path to the directory (include last /) ' +
                            'where the person mel folders are')

    group_training = parser.add_argument_group('training')
    group_training.add_argument('--num-epochs',
                                type=int,
                                help="The number of epochs to train for")
    group_training.add_argument(
        '--lr-dtor-isvoice',
        type=float,
    )
    group_training.add_argument(
        '--lr-tform',
        type=float,
    )

    group_training.add_argument(
        '--num-batches-dtor-isvoice',
        type=int,
    )
    group_training.add_argument(
        '--batch-size-dtor-isvoice',
        type=int,
    )

    group_training.add_argument(
        '--num-batches-tform',
        type=int,
    )
    group_training.add_argument(
        '--batch-size-tform',
        type=int,
    )

    group_model = parser.add_argument_group('model')
    group_model.add_argument('--identity-mode', help='One of [norm, cos, nn]')

    args = parser.parse_args()
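    # If a config JSON is given, use it as the base configuration and let any
    # explicitly passed CLI flags override the corresponding JSON values.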
    if args.config_json is not None:
        with open(args.config_json) as json_file:
            file_args = json.load(json_file)
        cli_dict = vars(args)
        for key in cli_dict:
            if cli_dict[key] is not None:
                file_args[key] = cli_dict[key]
        args.__dict__ = file_args

    print("CLI args are: ", args)
    with open("configs/basic.yml") as f:
        config = yaml.full_load(f)

    ############################
    # Setting up the constants #
    ############################

    if args.save_dir is not None and args.save_dir[-1] != "/":
        args.save_dir += "/"
    if args.load_dir is not None and args.load_dir[-1] != "/":
        args.load_dir += "/"

    SAVE_DTOR_ISVOICE = args.save_dir + FOLDER_DTOR_IV
    SAVE_TRANSFORMER = args.save_dir + FOLDER_TRANSFORMER

    ############################
    # Reproducibility Settings #
    ############################
    # Refer to https://pytorch.org/docs/stable/notes/randomness.html
    torch.manual_seed(args.torch_seed)
    torch.cuda.manual_seed_all(args.torch_seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # TODO Enable?
    # torch.set_default_tensor_type(torch.cuda.FloatTensor)

    #############################
    # Setting up Pytorch device #
    #############################
    # gpu_ids may come from the CLI as a string or from the JSON config as an int
    use_cpu = str(args.gpu_ids) == "-1"
    device = torch.device("cpu" if use_cpu else "cuda")

    ###############################################
    # Initialize the model and related optimizers #
    ###############################################

    if args.load_dir is None:
        start_epoch = 0
    else:
        start_epoch = int(args.load_dir.split("_")[-1][:-4])

    model = ProjectModel(config=config["transformer"],
                         embedder_path=args.embedder_path,
                         mel_size=args.mel_size,
                         style_size=args.style_size,
                         identity_mode=args.identity_mode,
                         cuda=(not use_cpu))
    model = model.to(device)
    tform_optimizer = torch.optim.Adam(model.transformer.parameters(),
                                       lr=args.lr_tform)
    tform_checkpointer = CheckpointManager(model.transformer, tform_optimizer,
                                           SAVE_TRANSFORMER,
                                           args.epoch_save_interval,
                                           start_epoch + 1)

    dtor_isvoice_optimizer = torch.optim.Adam(model.isvoice_dtor.parameters(),
                                              lr=args.lr_dtor_isvoice)
    dtor_isvoice_checkpointer = CheckpointManager(model.isvoice_dtor,
                                                  dtor_isvoice_optimizer,
                                                  SAVE_DTOR_ISVOICE,
                                                  args.epoch_save_interval,
                                                  start_epoch + 1)

    ###############################################

    # Load the checkpoint, if it is specified
    if args.load_dir is not None:
        tform_md, tform_od = load_checkpoint(SAVE_TRANSFORMER)
        model.transformer.load_state_dict(tform_md)
        tform_optimizer.load_state_dict(tform_od)

        dtor_isvoice_md, dtor_isvoice_od = load_checkpoint(SAVE_DTOR_ISVOICE)
        model.isvoice_dtor.load_state_dict(dtor_isvoice_md)
        dtor_isvoice_optimizer.load_state_dict(dtor_isvoice_od)

    ##########################
    # Declaring the datasets #
    ##########################

    dset_wrapper = VCTK_Wrapper(
        model.embedder,
        args.dset_num_people,
        args.dset_num_samples,
        args.mel_root,
        device,
    )

    if args.mel_size != dset_wrapper.mel_from_ids(0, 0).size()[-1]:
        raise RuntimeError("mel size arg is different from that in file")

    dset_isvoice_real = Isvoice_Dataset_Real(dset_wrapper, )
    dset_isvoice_fake = Isvoice_Dataset_Fake(dset_wrapper, model.embedder,
                                             model.transformer)
    dset_generator_train = Generator_Dataset(dset_wrapper, )
    # We're enforcing identity via a resnet connection for now, so unused
    # dset_identity_real = Identity_Dataset_Real(dset_wrapper,
    #                                            embedder)
    # dset_identity_fake = Identity_Dataset_Fake(dset_wrapper,
    #                                            embedder, transformer)

    collate_along_timeaxis = lambda x: collate_pad_tensors(x, pad_dim=1)
    dload_isvoice_real = DataLoader(dset_isvoice_real,
                                    batch_size=args.batch_size_dtor_isvoice,
                                    collate_fn=collate_along_timeaxis)
    dload_isvoice_fake = DataLoader(dset_isvoice_fake,
                                    batch_size=args.batch_size_dtor_isvoice,
                                    collate_fn=collate_along_timeaxis)
    dload_generator = DataLoader(dset_generator_train,
                                 batch_size=args.batch_size_tform,
                                 collate_fn=Generator_Dataset.collate_fn)

    #######################################################
    # The actual training loop gaaah what a rollercoaster #
    #######################################################
    train_start_time = datetime.now()
    print("Started Training at {}".format(train_start_time))
    for epoch in range(args.num_epochs):
        epoch_start_time = datetime.now()
        ###############
        # (D1) Train Real vs Fake Discriminator
        ###############
        train_dtor(model.isvoice_dtor, dtor_isvoice_optimizer,
                   dload_isvoice_real, dload_isvoice_fake,
                   args.num_batches_dtor_isvoice, device)
        dtor_isvoice_checkpointer.step()
        gc.collect()

        # Train generators here
        ################
        # (G) Update Generator
        ################
        val_loss = train_gen(model,
                             tform_optimizer,
                             dload_generator,
                             device,
                             num_batches=args.num_batches_tform)
        tform_checkpointer.step()
        gc.collect()
Example No. 4
        # Evaluate on the validation set
        # ---------------------------------------------------------------------

        validation_loss = validate(dataloader=validation_dataloader,
                                   model=model,
                                   loss_func=loss_func,
                                   epoch=epoch,
                                   args=args)

        # ---------------------------------------------------------------------
        # Take a step with the CheckpointManager
        # ---------------------------------------------------------------------

        # This will create a checkpoint if the current model is the best we've
        # seen yet, and also once every `step_size` epochs.
        checkpoint_manager.step(metric=validation_loss, epoch=epoch)

        # ---------------------------------------------------------------------
        # Update the learning rate of the optimizer (using the LR scheduler)
        # ---------------------------------------------------------------------

        # Take a step with the LR scheduler; print message when LR changes
        current_lr = update_lr(scheduler, optimizer, validation_loss)

        # Log the current value of the LR to TensorBoard
        if args.tensorboard:
            args.logger.add_scalar(tag='learning_rate',
                                   scalar_value=current_lr,
                                   global_step=epoch)

        # ---------------------------------------------------------------------