Example #1
0
def train_val(config):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_loader = get_dataloader(img_dir=config.train_img_dir,
                                  mask_dir=config.train_mask_dir,
                                  mode="train",
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers)
    val_loader = get_dataloader(img_dir=config.val_img_dir,
                                mask_dir=config.val_mask_dir,
                                mode="val",
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" %
        (config.lr, config.batch_size, config.model_type, config.data_type))

    if config.model_type not in [
            'UNet', 'R2UNet', 'AUNet', 'R2AUNet', 'SEUNet', 'SEUNet++',
            'UNet++', 'DAUNet', 'DANet', 'AUNetR', 'RendDANet', "RendUNet"
    ]:
        print('ERROR: model_type "%s" is not one of the supported models' %
              config.model_type)
        return
    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "AUNet":
        model = AUNet()
    elif config.model_type == "R2UNet":
        model = R2UNet()
    elif config.model_type == "SEUNet":
        model = SEUNet(useCSE=False, useSSE=False, useCSSE=True)
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "DANet":
        model = DANet(backbone='resnet101', nclass=config.output_ch)
    elif config.model_type == "AUNetR":
        model = AUNet_R16(n_classes=1, learned_bilinear=True)
    elif config.model_type == "RendDANet":
        model = RendDANet(backbone='resnet101', nclass=config.output_ch)
    elif config.model_type == "RendUNet":
        model = RendUNet()
    else:
        model = UNet()

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device)
    print('# parameters:', sum(param.numel() for param in model.parameters()))
    if config.optimizer == "sgd":
        optimizer = SGD(model.seg.parameters(),
                        lr=1e-2,
                        weight_decay=1e-6,
                        momentum=0.9)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    if config.loss == "dice":
        criterion = DiceLoss()
    elif config.loss == "bce":
        criterion = nn.BCELoss()
    elif config.loss == "mix":
        criterion = MixLoss()
    else:
        criterion = MultiRendLoss_v10()

    scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    global_step = 0
    best_dice = 0.0
    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        with tqdm(total=config.num_train,
                  desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img') as train_pbar:
            model.train()
            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float32)
                mask = mask.to(device, dtype=torch.float32)
                output = model(image)
                loss = criterion(output, mask)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1
            scheduler.step()
        epoch_dice = 0.0
        epoch_acc = 0.0
        epoch_sen = 0.0
        epoch_spe = 0.0
        epoch_pre = 0.0
        current_num = 0
        with tqdm(total=config.num_val,
                  desc="Epoch %d / %d validation round" %
                  (epoch + 1, config.num_epochs),
                  unit='img') as val_pbar:
            model.eval()
            locker = 0
            for image, mask in val_loader:
                current_num += image.shape[0]
                image = image.to(device, dtype=torch.float32)
                mask = mask.to(device, dtype=torch.float32)
                output = model(image)
                pred = torch.sigmoid(output['fine'])
                batch_dice = dice_coeff(mask, pred).item()
                epoch_dice += batch_dice * image.shape[0]
                epoch_acc += get_accuracy(pred=pred,
                                          true=mask) * image.shape[0]
                epoch_sen += get_sensitivity(pred=pred,
                                             true=mask) * image.shape[0]
                epoch_spe += get_specificity(pred=pred,
                                             true=mask) * image.shape[0]
                epoch_pre += get_precision(pred=pred,
                                           true=mask) * image.shape[0]
                if locker == 200:
                    writer.add_images('masks/true', mask, epoch + 1)
                    writer.add_images('masks/pred', pred > 0.5, epoch + 1)
                val_pbar.set_postfix(**{'dice (batch)': batch_dice})
                val_pbar.update(image.shape[0])
                locker += 1
            epoch_dice /= float(current_num)
            epoch_acc /= float(current_num)
            epoch_sen /= float(current_num)
            epoch_spe /= float(current_num)
            epoch_pre /= float(current_num)
            epoch_f1 = get_F1(SE=epoch_sen, PR=epoch_pre)
            if epoch_dice > best_dice:
                best_dice = epoch_dice
                writer.add_scalar('Best Dice/test', best_dice, epoch + 1)
                torch.save(
                    model, config.result_path + "/%s_%s_%d.pth" %
                    (config.model_type, str(epoch_dice), epoch + 1))
            logging.info('Validation Dice Coeff: {}'.format(epoch_dice))
            print("epoch dice: " + str(epoch_dice))
            writer.add_scalar('Dice/test', epoch_dice, epoch + 1)
            writer.add_scalar('Acc/test', epoch_acc, epoch + 1)
            writer.add_scalar('Sen/test', epoch_sen, epoch + 1)
            writer.add_scalar('Spe/test', epoch_spe, epoch + 1)
            writer.add_scalar('Pre/test', epoch_pre, epoch + 1)
            writer.add_scalar('F1/test', epoch_f1, epoch + 1)

    writer.close()
    print("Training finished")
Example #2
0
    def train(self, epoch, hparam=None):
        '''
        Inputs:
          - hparam: dictionary of hyperparameters. If given, the average epoch loss and the
            train/validation accuracy are written to TensorBoard, and after half of the
            training epochs each checkpoint saves the model, optimizer, and scaler state
            dicts together with the current epoch, stats, and config.
        '''

        model_start_time = self.config['model_start_time']
        previous_epoch = self.config['previous_epoch']
        check_every_epoch = self.config['check_every_epoch']

        self.config['hparam'] = hparam

        epoch += previous_epoch  # offset the target epoch when resuming from a previously saved model

        if hparam:
            writer = SummaryWriter('runs/' + model_start_time)

        checkpoint_cycle_flag = True
        num_batch = len(self.train_loader)
        self.model.train()
        for i in range(previous_epoch + 1, epoch + 1):

            total_loss = 0
            iter_loss_history = []
            Y_pred_all = []
            Y_tr_all = []

            if checkpoint_cycle_flag:
                checkpoint_cycle_flag = False
                checkpoint_start_time = time.time()

            for j, data in zip(s := trange(num_batch, leave=False),
                               self.train_loader):

                Xtr, Ytr = data
                Xtr, Ytr = Xtr.to(**self.to_float_cuda,
                                  non_blocking=True), Ytr.cuda(
                                      non_blocking=True)

                ################################## Future changes ##########################################################

                loss, y_pred = self.train_fn(Xtr, Ytr)

                ############################################################################################################

                total_loss += loss

                # Iter Book keeping
                Y_pred_all.append(y_pred)
                Y_tr_all.append(Ytr)
                iter_loss_history.append(loss)

                # update progress bar
                s.set_description(f'Epoch {i}/{epoch} Loss: {loss:.4f} ')

            avg_loss = total_loss / num_batch

            # Epoch Book keeping
            self.stats['iter_loss'].append(iter_loss_history)
            self.stats['avg_loss'].append(avg_loss)

            # Enter the checkpoint block every check_every_epoch epochs and on the last epoch
            if i % check_every_epoch == 0 or i == epoch:
                checkpoint_cycle_flag = True
                cur_lr = self.optimizer.param_groups[0]['lr']

                # check train accuracy by using saved results during forward pass to save computation.
                Y_pred_all = torch.argmax(torch.cat(Y_pred_all), dim=1)
                Y_tr_all = torch.cat(Y_tr_all)
                train_accuracy = (Y_pred_all == Y_tr_all).float().mean()

                # check val accuracy
                val_accuracy, val_loss = self._check_accuracy(self.val_loader)

                # check update ratio
                ratio = self._check_update_ratio(cur_lr)

                print(
                    f'Epoch: {i}/{epoch}, train loss: {avg_loss:.4f}, val loss: {val_loss:.4f}, train acc: {train_accuracy:.4f}, val acc: {val_accuracy:.4f}, lr: {cur_lr:.4e}, update ratio: {ratio:.2e}, took {(time.time()-checkpoint_start_time):.2f} seconds'
                )

                # Checkpoint Book keeping
                self.stats['train_acc'].append(train_accuracy)
                self.stats['val_acc'].append(val_accuracy)
                self.stats['ratio'].append(ratio)

                if hparam:
                    writer.add_scalar('Epoch loss', avg_loss, i)
                    writer.add_scalars('accuracy', {
                        'train': train_accuracy,
                        'val': val_accuracy
                    }, i)

                    # only save model checkpoint after half of the training process
                    if i > epoch // 2:
                        self._save_checkpoint(epoch=i)

            # decay learning rate after completing one epoch
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
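
The block marked "Future changes" delegates the forward/backward pass to self.train_fn. A minimal sketch of what such a hook could look like (an assumption, not the original implementation; self.criterion is a hypothetical attribute), returning the scalar loss and the raw predictions consumed by the accuracy bookkeeping above:

    def train_fn(self, Xtr, Ytr):
        # forward pass and loss (self.criterion assumed, e.g. nn.CrossEntropyLoss)
        y_pred = self.model(Xtr)
        loss = self.criterion(y_pred, Ytr)
        # backward pass and parameter update
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # return a Python float and detached logits so the bookkeeping
        # lists do not keep the autograd graph alive
        return loss.item(), y_pred.detach()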
Example #3
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model_to_resize = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer,
                                         args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            loss = model(inputs,
                         masked_lm_labels=labels) if args.mlm else model(
                             inputs, labels=labels)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
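
The collate helper above pads variable-length token id tensors to a common length before batching. A small standalone illustration of that behavior (the example tensors and the padding value 0 are placeholders for tokenizer.pad_token_id):

import torch
from torch.nn.utils.rnn import pad_sequence

examples = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
# batch_first=True yields a (batch, max_len) tensor; shorter sequences
# are right-padded with padding_value.
batch = pad_sequence(examples, batch_first=True, padding_value=0)
print(batch)  # tensor([[5, 6, 7], [8, 9, 0]])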
Example #4
def train(gpu, args):
    """Create the model and start the training."""

    rank = args.nr * args.num_gpus + gpu
    if gpu == 1:
        gpu = 3

    dist.init_process_group(backend="nccl",
                            world_size=args.world_size,
                            rank=rank)

    if args.batch_size == 1 and args.use_bn is True:
        raise Exception("batch_size must be greater than 1 when use_bn is enabled")

    torch.autograd.set_detect_anomaly(True)
    torch.manual_seed(args.torch_seed)
    torch.cuda.manual_seed(args.cuda_seed)

    torch.cuda.set_device(gpu)

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    criterion = DiceBCELoss()
    # criterion = nn.CrossEntropyLoss(ignore_index=253)
    # Create network
    if args.model == 'DeepLab':
        model = DeeplabMulti(num_classes=args.num_classes)
        if args.restore_from is None:
            pass
        elif args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        elif args.restore_from is not None:
            saved_state_dict = torch.load(args.restore_from)
            model.load_state_dict(saved_state_dict)
            print("Loaded state dicts for model")
        # if args.restore_from is not None:
        #     new_params = model.state_dict().copy()
        #     for i in saved_state_dict:
        #         # Scale.layer5.conv2d_list.3.weight
        #         i_parts = i.split('.')
        #         # print i_parts
        #         if not args.num_classes == 19 or not i_parts[1] == 'layer5':
        #             new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        #             # print i_parts
        #     model.load_state_dict(new_params)

    if not args.no_logging:
        if not os.path.isdir(args.log_dir):
            os.mkdir(args.log_dir)
        log_dir = os.path.join(args.log_dir, args.exp_dir)
        if not os.path.isdir(log_dir):
            os.mkdir(log_dir)
        if args.exp_name == "":
            exp_name = datetime.datetime.now().strftime("%H%M%S-%Y%m%d")
        else:
            exp_name = args.exp_name
        log_dir = os.path.join(log_dir, exp_name)
        writer = SummaryWriter(log_dir)

    model.train()
    # model.cuda(gpu)
    model = model.cuda(device=gpu)

    if args.num_gpus > 0 or torch.cuda.device_count() > 0:
        model = DistributedDataParallel(model,
                                        device_ids=[gpu],
                                        find_unused_parameters=True)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes)
    model_D2 = FCDiscriminator(num_classes=args.num_classes)
    start_epoch = 0
    if "http" not in args.restore_from and args.restore_from is not None:
        root, extension = args.restore_from.strip().split(".")
        D1pth = root + "_D1." + extension
        D2pth = root + "_D2." + extension
        saved_state_dict = torch.load(D1pth)
        model_D1.load_state_dict(saved_state_dict)
        saved_state_dict = torch.load(D2pth)
        model_D2.load_state_dict(saved_state_dict)
        start_epoch = int(re.findall(r'[\d]+', root)[-1])
        print("Loaded state dict for models D1 and D2")

    model_D1.train()
    # model_D1.cuda(gpu)
    model_D2.train()
    # model_D2.cuda(gpu)

    model_D1 = model_D1.cuda(device=gpu)
    model_D2 = model_D2.cuda(device=gpu)

    if args.num_gpus > 0 or torch.cuda.device_count() > 0:
        model_D1 = DistributedDataParallel(model_D1,
                                           device_ids=[gpu],
                                           find_unused_parameters=True)
        model_D2 = DistributedDataParallel(model_D2,
                                           device_ids=[gpu],
                                           find_unused_parameters=True)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = SyntheticSmokeTrain(args={},
                                        dataset_limit=args.num_steps *
                                        args.iter_size * args.batch_size,
                                        image_shape=input_size,
                                        dataset_mean=IMG_MEAN)

    train_sampler = DistributedSampler(train_dataset,
                                       num_replicas=args.world_size,
                                       rank=rank,
                                       shuffle=True)
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  sampler=train_sampler)

    # trainloader = data.DataLoader(
    #     GTA5DataSet(args.data_dir, args.data_list, max_iters=args.num_steps * args.iter_size * args.batch_size,
    #                 crop_size=input_size,
    #                 scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN),
    #     batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    print("Length of train dataloader: ", len(trainloader))
    target_dataset = SimpleSmokeVal(args={},
                                    image_size=input_size_target,
                                    dataset_mean=IMG_MEAN)
    target_sampler = DistributedSampler(target_dataset,
                                        num_replicas=args.world_size,
                                        rank=rank,
                                        shuffle=True)
    targetloader = data.DataLoader(target_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.num_workers,
                                   pin_memory=True,
                                   sampler=target_sampler)

    # targetloader = data.DataLoader(cityscapesDataSet(args.data_dir_target, args.data_list_target,
    #                                                  max_iters=args.num_steps * args.iter_size * args.batch_size,
    #                                                  crop_size=input_size_target,
    #                                                  scale=False, mirror=args.random_mirror, mean=IMG_MEAN,
    #                                                  set=args.set),
    #                                batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
    #                                pin_memory=True)

    targetloader_iter = enumerate(targetloader)
    print("Length of train dataloader: ", len(targetloader))
    # implement model.optim_parameters(args) to handle different models' lr setting

    optimizer = optim.SGD(model.module.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    optimizer_D1 = optim.Adam(model_D1.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    optimizer_D2 = optim.Adam(model_D2.parameters(),
                              lr=args.learning_rate_D,
                              betas=(0.9, 0.99))
    optimizer_D2.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()

    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')
    interp_target = nn.Upsample(size=(input_size_target[1],
                                      input_size_target[0]),
                                mode='bilinear')

    # labels for adversarial training
    source_label = 0
    target_label = 1

    for i_iter in range(start_epoch, args.num_steps):

        loss_seg_value1 = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        loss_seg_value2 = 0
        loss_adv_target_value2 = 0
        loss_D_value2 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        optimizer_D1.zero_grad()
        optimizer_D2.zero_grad()
        adjust_learning_rate_D(optimizer_D1, i_iter)
        adjust_learning_rate_D(optimizer_D2, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            for param in model_D2.parameters():
                param.requires_grad = False

            # train with source
            # try:
            _, batch = next(trainloader_iter)
            # except StopIteration:
            # trainloader = data.DataLoader(
            #     SyntheticSmokeTrain(args={}, dataset_limit=args.num_steps * args.iter_size * args.batch_size,
            #                 image_shape=input_size, dataset_mean=IMG_MEAN),
            #     batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
            # trainloader_iter = iter(trainloader)
            # _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(gpu)
            # print("Shape of labels", labels.shape)
            # print("Are labels all zero? ")
            # for i in range(labels.shape[0]):
            #     print("{}: All zero? {}".format(i, torch.all(labels[i]==0)))
            #     print("{}: All 255? {}".format(i, torch.all(labels[i]==255)))
            #     print("{}: Mean = {}".format(i, torch.mean(labels[i])))

            pred1, pred2 = model(images)
            # print("Pred1 and Pred2 original size: {}, {}".format(pred1.shape, pred2.shape))
            pred1 = interp(pred1)
            pred2 = interp(pred2)
            # print("Pred1 and Pred2 upsampled size: {}, {}".format(pred1.shape, pred2.shape))
            # for pred, name in zip([pred1, pred2], ['pred1', 'pred2']):
            #     print(name)
            #     for i in range(pred.shape[0]):
            #         print("{}: All zero? {}".format(i, torch.all(pred[i]==0)))
            #         print("{}: All 255? {}".format(i, torch.all(pred[i]==255)))
            #         print("{}: Mean = {}".format(i, torch.mean(pred[i])))

            loss_seg1 = loss_calc(pred1, labels, gpu, criterion)
            loss_seg2 = loss_calc(pred2, labels, gpu, criterion)
            loss = loss_seg2 + args.lambda_seg * loss_seg1

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            # print("Seg1 loss: ",loss_seg1, args.iter_size)
            # print("Seg2 loss: ",loss_seg2, args.iter_size)

            loss_seg_value1 += loss_seg1.data.cpu().item() / args.iter_size
            loss_seg_value2 += loss_seg2.data.cpu().item() / args.iter_size

            # train with target
            # try:
            _, batch = next(targetloader_iter)
            # except StopIteration:
            #     targetloader = data.DataLoader(
            #         SimpleSmokeVal(args = {}, image_size=input_size_target, dataset_mean=IMG_MEAN),
            #                         batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
            #                         pin_memory=True)
            #     targetloader_iter = iter(targetloader)
            #     _, batch = next(targetloader_iter)

            images, _, _ = batch
            images = Variable(images).cuda(gpu)

            pred_target1, pred_target2 = model(images)
            pred_target1 = interp_target(pred_target1)
            pred_target2 = interp_target(pred_target2)

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_adv_target1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(gpu))

            loss_adv_target2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(gpu))

            loss = args.lambda_adv_target1 * loss_adv_target1 + args.lambda_adv_target2 * loss_adv_target2
            loss = loss / args.iter_size
            loss.backward()
            loss_adv_target_value1 += loss_adv_target1.data.cpu().item(
            ) / args.iter_size
            loss_adv_target_value2 += loss_adv_target2.data.cpu().item(
            ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            for param in model_D2.parameters():
                param.requires_grad = True

            # train with source
            pred1 = pred1.detach()
            pred2 = pred2.detach()

            D_out1 = model_D1(F.softmax(pred1, dim=1))
            D_out2 = model_D2(F.softmax(pred2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(source_label)).cuda(gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(source_label)).cuda(gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().item()
            loss_D_value2 += loss_D2.data.cpu().item()

            # train with target
            pred_target1 = pred_target1.detach()
            pred_target2 = pred_target2.detach()

            D_out1 = model_D1(F.softmax(pred_target1, dim=1))
            D_out2 = model_D2(F.softmax(pred_target2, dim=1))

            loss_D1 = bce_loss(
                D_out1,
                Variable(
                    torch.FloatTensor(
                        D_out1.data.size()).fill_(target_label)).cuda(gpu))

            loss_D2 = bce_loss(
                D_out2,
                Variable(
                    torch.FloatTensor(
                        D_out2.data.size()).fill_(target_label)).cuda(gpu))

            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D2 = loss_D2 / args.iter_size / 2

            loss_D1.backward()
            loss_D2.backward()

            loss_D_value1 += loss_D1.data.cpu().item()
            loss_D_value2 += loss_D2.data.cpu().item()

        optimizer.step()
        optimizer_D1.step()
        optimizer_D2.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f} loss_seg2 = {3:.3f} loss_adv1 = {4:.3f}, loss_adv2 = {5:.3f} loss_D1 = {6:.3f} loss_D2 = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value1, loss_seg_value2,
                    loss_adv_target_value1, loss_adv_target_value2,
                    loss_D_value1, loss_D_value2))
        # guard logging: writer only exists when args.no_logging is False
        if not args.no_logging:
            writer.add_scalar('loss/train/segmentation/1', loss_seg_value1,
                              i_iter)
            writer.add_scalar('loss/train/segmentation/2', loss_seg_value2,
                              i_iter)
            writer.add_scalar('loss/train/adversarial/1',
                              loss_adv_target_value1, i_iter)
            writer.add_scalar('loss/train/adversarial/2',
                              loss_adv_target_value2, i_iter)
            writer.add_scalar('loss/train/domain/1', loss_D_value1, i_iter)
            writer.add_scalar('loss/train/domain/2', loss_D_value2, i_iter)

        if i_iter >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'smoke_cross_entropy_multigpu_' +
                    str(args.num_steps_stop) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(
                    args.snapshot_dir, 'smoke_cross_entropy_multigpu_' +
                    str(args.num_steps_stop) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(
                    args.snapshot_dir, 'smoke_cross_entropy_multigpu_' +
                    str(args.num_steps_stop) + '_D2.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'smoke_cross_entropy_multigpu_' + str(i_iter) + '.pth'))
            torch.save(
                model_D1.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'smoke_cross_entropy_multigpu_' + str(i_iter) + '_D1.pth'))
            torch.save(
                model_D2.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'smoke_cross_entropy_multigpu_' + str(i_iter) + '_D2.pth'))
        if not args.no_logging:
            writer.flush()
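
A minimal launch sketch for this worker function, assuming the usual one-process-per-GPU spawn pattern: args is expected to carry nr, num_gpus, and world_size as used above; the nodes field and the master address/port values are assumptions.

import os
import torch.multiprocessing as mp

def main(args):
    # dist.init_process_group() inside train() reads these environment variables
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    args.world_size = args.num_gpus * args.nodes  # nodes: assumed field, number of machines
    # spawn one process per local GPU; each process calls train(gpu, args)
    mp.spawn(train, nprocs=args.num_gpus, args=(args,))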
Example #5
0
class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for Transformers.
    """

    model: PreTrainedModel
    args: TrainingArguments
    data_collator: DataCollator
    train_dataset: Optional[Dataset]
    eval_dataset: Optional[Dataset]
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
    prediction_loss_only: bool
    tb_writer: Optional["SummaryWriter"] = None
    optimizers: Tuple[torch.optim.Optimizer,
                      torch.optim.lr_scheduler.LambdaLR] = None
    global_step: Optional[int] = None
    epoch: Optional[float] = None

    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        neptune,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only=False,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Tuple[torch.optim.Optimizer,
                          torch.optim.lr_scheduler.LambdaLR] = None,
    ):
        """
        Trainer is a simple but feature-complete training and eval loop for PyTorch,
        optimized for Transformers.

        Args:
            prediction_loss_only:
                (Optional) in evaluation and prediction, only return the loss
        """
        self.model = model.to(args.device)
        self.args = args
        self.neptune = neptune
        if data_collator is not None:
            self.data_collator = data_collator
        else:
            self.data_collator = DefaultDataCollator()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only
        self.optimizers = optimizers
        if tb_writer is not None:
            self.tb_writer = tb_writer
        elif is_tensorboard_available() and self.is_world_master():
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )
        if is_wandb_available():
            self._setup_wandb()
        else:
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
        set_seed(self.args.seed)
        # Create output directory if needed
        if self.is_world_master():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available():
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant way and not need to do this in the future.
            self.model.config.xla_device = True

    def get_train_dataloader(self) -> DataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        if is_torch_tpu_available():
            train_sampler = get_tpu_sampler(self.train_dataset)
        else:
            train_sampler = (SequentialSampler(self.train_dataset)
                             if self.args.local_rank == -1 else
                             SequentialDistributedSampler(self.train_dataset))

        data_loader = DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator.collate_batch,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader

    def get_eval_dataloader(self,
                            eval_dataset: Optional[Dataset] = None
                            ) -> DataLoader:
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        if is_torch_tpu_available():
            sampler = SequentialDistributedSampler(
                eval_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal())
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(eval_dataset)
        else:
            sampler = SequentialSampler(eval_dataset)

        data_loader = DataLoader(
            eval_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator.collate_batch,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        # We use the same batch_size as for eval.
        if is_torch_tpu_available():
            sampler = SequentialDistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal())
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(test_dataset)
        else:
            sampler = SequentialSampler(test_dataset)

        data_loader = DataLoader(
            test_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator.collate_batch,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader

    def get_optimizers(
        self, num_training_steps: int
    ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well.
        If you want to use something else, you can pass a tuple in the Trainer's init,
        or override this method in a subclass.
        """
        if self.optimizers is not None:
            return self.optimizers
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                self.args.weight_decay,
            },
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.args.learning_rate,
                          eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=num_training_steps)
        return optimizer, scheduler

    def _setup_wandb(self):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can override this method to customize the setup if needed.  Find more information at https://docs.wandb.com/huggingface
        You can also override the following environment variables:

        Environment:
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
                or "all" to log gradients and parameters
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
        """
        if self.is_world_master():
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"),
                       config=vars(self.args))
            # keep track of model topology and gradients
            if os.getenv("WANDB_WATCH") != "false":
                wandb.watch(self.model,
                            log=os.getenv("WANDB_WATCH", "gradients"),
                            log_freq=max(100, self.args.logging_steps))

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get num of examples from a DataLoader, by accessing its Dataset.
        """
        return len(dataloader.dataset)

    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to model if model to train has been instantiated from a local path
                If present, we will try reloading the optimizer/scheduler states from there.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (model_path is not None
                and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
                and os.path.isfile(os.path.join(model_path, "scheduler.pt"))):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"),
                           map_location=self.args.device))
            scheduler.load_state_dict(
                torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(),
                                       metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size(
            )
        else:
            total_train_batch_size = (self.args.train_batch_size *
                                      self.args.gradient_accumulation_steps *
                                      (torch.distributed.get_world_size()
                                       if self.args.local_rank != -1 else 1))
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d",
                    self.args.per_device_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d",
                    self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            self.global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        accumulation_loss = 0.0
        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        train_iterator = trange(epochs_trained,
                                int(num_train_epochs),
                                desc="Epoch",
                                disable=not self.is_local_master())
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(
                    train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(
                    train_dataloader,
                    [self.args.device]).per_device_loader(self.args.device)
                epoch_iterator = tqdm(parallel_loader,
                                      desc="Iteration",
                                      disable=not self.is_local_master())
            else:
                epoch_iterator = tqdm(train_dataloader,
                                      desc="Iteration",
                                      disable=not self.is_local_master())

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                loss = self._training_step(model, inputs, optimizer)
                accumulation_loss += loss
                tr_loss += loss

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <=
                        self.args.gradient_accumulation_steps and
                    (step + 1) == len(epoch_iterator)):
                    if self.args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer),
                            self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(optimizer)
                    else:
                        optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(train_dataloader)

                    if is_torch_tpu_available():
                        if xm.get_ordinal() == 0:
                            self.neptune.log_metric('loss', self.global_step,
                                                    accumulation_loss)
                    else:
                        self.neptune.log_metric('loss', self.global_step,
                                                accumulation_loss)

                    accumulation_loss = 0.0

                    if (self.args.logging_steps > 0
                            and self.global_step % self.args.logging_steps
                            == 0) or (self.global_step == 1
                                      and self.args.logging_first_step):
                        logs: Dict[str, float] = {}
                        logs["loss"] = (tr_loss -
                                        logging_loss) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >=
                            version.parse("1.4") else scheduler.get_lr()[0])
                        logging_loss = tr_loss

                        self._log(logs)

                        if self.args.evaluate_during_training:
                            self.evaluate()

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert model.module is self.model
                        else:
                            assert model is self.model
                        # Save model checkpoint
                        output_dir = os.path.join(
                            self.args.output_dir,
                            f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                        self.save_model(output_dir)

                        if self.is_world_master():
                            self._rotate_checkpoints()

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(optimizer.state_dict(),
                                    os.path.join(output_dir, "optimizer.pt"))
                            xm.save(scheduler.state_dict(),
                                    os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_master():
                            torch.save(
                                optimizer.state_dict(),
                                os.path.join(output_dir, "optimizer.pt"))
                            torch.save(
                                scheduler.state_dict(),
                                os.path.join(output_dir, "scheduler.pt"))

                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.tb_writer:
            self.tb_writer.close()

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        return TrainOutput(self.global_step, tr_loss / self.global_step)

    def _log(self,
             logs: Dict[str, float],
             iterator: Optional[tqdm] = None) -> None:
        if self.epoch is not None:
            logs["epoch"] = self.epoch
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_master():
                wandb.log(logs, step=self.global_step)
        output = json.dumps({**logs, **{"step": self.global_step}})
        if iterator is not None:
            iterator.write(output)
        else:
            print(output)

    def _training_step(self, model: nn.Module, inputs: Dict[str, torch.Tensor],
                       optimizer: torch.optim.Optimizer) -> float:
        model.train()
        for k, v in inputs.items():
            inputs[k] = v.to(self.args.device)

        outputs = model(**inputs)
        loss = outputs[0]  # model outputs are always a tuple in transformers (see docs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # average the loss across GPUs under nn.DataParallel
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()

    def is_local_master(self) -> bool:
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

    def is_world_master(self) -> bool:
        """
        This will be True only in one process, even in distributed mode,
        even when training on multiple machines.
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank(
            ) == 0

    def save_model(self, output_dir: Optional[str] = None):
        """
        Saving best-practices: if you use default names for the model,
        you can reload it using from_pretrained().

        Will only save from the world_master process (unless in TPUs).
        """

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_master():
            self._save(output_dir)

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir,
                                               "training_args.bin"))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError(
                "Trainer.model appears to not be a PreTrainedModel")

        xm.rendezvous("saving_checkpoint")
        self.model.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError(
                "Trainer.model appears to not be a PreTrainedModel")
        self.model.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def _sorted_checkpoints(self,
                            checkpoint_prefix=PREFIX_CHECKPOINT_DIR,
                            use_mtime=False) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [
            str(x)
            for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")
        ]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append(
                    (os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append(
                        (int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [
            checkpoint[1] for checkpoint in checkpoints_sorted
        ]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(
            0,
            len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:
                                                       number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info(
                "Deleting older checkpoint [{}] due to args.save_total_limit".
                format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        prediction_loss_only: Optional[bool] = None,
    ) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent.

        Args:
            eval_dataset: (Optional) Pass a dataset if you wish to override
                the one on the instance.
            prediction_loss_only: (Optional) If set, only the evaluation loss
                is computed and no predictions are accumulated.
        Returns:
            A dict containing:
                - the eval loss
                - the potential metrics computed from the predictions
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self._prediction_loop(
            eval_dataloader,
            description="Evaluation",
            prediction_loss_only=prediction_loss_only)

        self._log(output.metrics)

        if self.args.tpu_metrics_debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in evaluate().
        """
        test_dataloader = self.get_test_dataloader(test_dataset)

        return self._prediction_loop(test_dataloader, description="Prediction")
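    # Usage sketch (illustrative only; `trainer` and `test_dataset` are assumed
    # names for an instance of this class and a torch Dataset):
    #
    #     metrics = trainer.evaluate()            # {"eval_loss": ..., "eval_*": ...}
    #     output = trainer.predict(test_dataset)  # PredictionOutput
    #     preds, label_ids = output.predictions, output.label_ids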

    def _prediction_loop(
            self,
            dataloader: DataLoader,
            description: str,
            prediction_loss_only: Optional[bool] = None) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

        Works both with or without labels.
        """

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: Optional[torch.Tensor] = None
        label_ids: Optional[torch.Tensor] = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(
                dataloader,
                [self.args.device]).per_device_loader(self.args.device)

        for inputs in tqdm(dataloader, desc=description):
            has_labels = any(
                inputs.get(k) is not None
                for k in ["labels", "lm_labels", "masked_lm_labels"])

            for k, v in inputs.items():
                inputs[k] = v.to(self.args.device)

            with torch.no_grad():
                outputs = model(**inputs)
                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    eval_losses += [step_eval_loss.mean().item()]
                else:
                    logits = outputs[0]

            if not prediction_loss_only:
                if preds is None:
                    preds = logits.detach()
                else:
                    preds = torch.cat((preds, logits.detach()), dim=0)
                if inputs.get("labels") is not None:
                    if label_ids is None:
                        label_ids = inputs["labels"].detach()
                    else:
                        label_ids = torch.cat(
                            (label_ids, inputs["labels"].detach()), dim=0)

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = self.distributed_concat(
                    preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = self.distributed_concat(
                    label_ids,
                    num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids,
                                           torch.cat)

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds,
                                label_ids=label_ids,
                                metrics=metrics)

    def distributed_concat(self, tensor: torch.Tensor,
                           num_total_examples: int) -> torch.Tensor:
        assert self.args.local_rank != -1

        output_tensors = [
            tensor.clone() for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(output_tensors, tensor)

        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        output = concat[:num_total_examples]
        return output
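
# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the gradient-accumulation pattern used in
# the training loop above: scale the loss, accumulate gradients, then clip,
# step the optimizer and scheduler, and zero the gradients once every
# `accumulation_steps` micro-batches. The toy model, data and hyper-parameters
# below are illustrative assumptions only.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def gradient_accumulation_demo(accumulation_steps=4, max_grad_norm=1.0):
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    data = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(16)]

    for step, (x, y) in enumerate(data):
        # scale so the accumulated gradient matches one large batch
        loss = F.mse_loss(model(x), y) / accumulation_steps
        loss.backward()
        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()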
예제 #6
0
def main():
    # load expert data
    print(args.data_set_path)
    dataset = ExpertDataSet(args.data_set_path)
    data_loader = data.DataLoader(
        dataset=dataset,
        batch_size=args.expert_batch_size,
        shuffle=True,
        num_workers=0
    )
    p_state_sizes = [args.n_state, args.n_state, args.n_state, args.n_state + args.n_action]
    p_action_sizes = [args.n_onehot_action, args.n_multihot_action, args.n_continuous_action, args.n_continuous_state - args.n_continuous_action]
    
    d_state_sizes = [args.n_state, args.n_state, args.n_state, args.n_state]
    d_action_sizes = [args.n_onehot_action, args.n_multihot_action, args.n_continuous_action, args.n_continuous_state - args.n_continuous_action]

    policy = MultiPolicy(p_state_sizes, p_action_sizes, onehot_action_sections, onehot_state_sections, state_0 = dataset.state)
    discriminator = MultiDiscriminator(d_state_sizes, d_action_sizes)
    discriminator_criterion = nn.BCELoss()
    if write_scalar:
        writer = SummaryWriter(log_dir='runs/' + model_name)

    # load net  models
    if load_model:
        discriminator = torch.load('./model_pkl/multi_policy/D_' + model_name + '.pkl')
        policy = torch.load('./model_pkl/multi_policy/P_' + model_name + '.pkl')
    print('#############  start training  ##############')

    # update discriminator
    num = 0
    for ep in tqdm(range(args.training_epochs)):
        # collect data from environment for ppo update
        policy.train()
        discriminator.train()
        start_time = time.time()
        memory, n_trajs = policy.collect_samples(batch_size=args.sample_batch_size)
        # print('sample_data_time:{}'.format(time.time()-start_time))
        batch = memory.sample()
        gen_state = torch.cat(batch.state, dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach()
        gen_action = torch.cat(batch.action, dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach()
        gen_next_state = torch.cat(batch.next_state, dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach()
        old_log_prob = torch.cat(batch.old_log_prob, dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach()
        mask = torch.cat(batch.mask, dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach()
        
        gen_d_state, gen_d_action = make_d_inputs(gen_state, gen_action, gen_next_state)
        if ep % 1 == 0:  # update the discriminator every epoch
        # if (d_slow_flag and ep % 50 == 0) or (not d_slow_flag and ep % 1 == 0):
            d_loss = torch.empty(0, device=device)
            p_loss = torch.empty(0, device=device)
            v_loss = torch.empty(0, device=device)
            gen_r = torch.empty(0, device=device)
            expert_r = torch.empty(0, device=device)
            for expert_state_batch, expert_action_batch, expert_next_state_batch in data_loader:
                expert_d_state, expert_d_action = make_d_inputs(expert_state_batch.to(device), expert_action_batch.to(device), expert_next_state_batch.to(device))
                gen_r = discriminator(gen_d_state, gen_d_action)
                expert_r = discriminator(expert_d_state, expert_d_action)

                discriminator.optimizer.zero_grad()
                d_loss = discriminator_criterion(gen_r, torch.zeros(gen_r.shape, device=device)) + \
                         discriminator_criterion(expert_r, torch.ones(expert_r.shape, device=device))
                variance = 0.5 * torch.var(gen_r.to(device)) + 0.5 * torch.var(expert_r.to(device))
                total_d_loss = d_loss - 10 * variance
                d_loss.backward()
                # total_d_loss.backward()
                discriminator.optimizer.step()
            if write_scalar:
                writer.add_scalar('loss/d_loss', d_loss, ep)
                writer.add_scalar('loss/total_d_loss', total_d_loss, ep)
                writer.add_scalar('loss/variance', 10 * variance, ep)

        if ep % 1 == 0:
            # update PPO
            gen_r = discriminator(gen_d_state, gen_d_action)
            optimize_iter_num = int(math.ceil(gen_state.shape[0] / args.ppo_mini_batch_size))
            for ppo_ep in range(args.ppo_optim_epoch):
                for i in range(optimize_iter_num):
                    num += 1
                    index = slice(i * args.ppo_mini_batch_size, min((i + 1) * args.ppo_mini_batch_size, gen_state.shape[0]))
                    gen_state_batch, gen_action_batch, gen_next_state_batch, old_log_prob_batch, mask_batch, gen_r_batch = \
                        gen_state[index], gen_action[index], gen_next_state[index], old_log_prob[index], mask[index], gen_r[index]
                    v_loss, p_loss = ppo_step(policy,
                                            gen_state_batch,
                                            gen_action_batch, 
                                            gen_next_state_batch,
                                            gen_r_batch, old_log_prob_batch,
                                            mask_batch, args.ppo_clip_epsilon)
        policy.eval()
        discriminator.eval()
        gen_d_state, gen_d_action = make_d_inputs(gen_state, gen_action, gen_next_state)
        expert_d_state, expert_d_action = make_d_inputs(expert_state_batch.to(device), expert_action_batch.to(device), expert_next_state_batch.to(device))
        gen_r = discriminator(gen_d_state, gen_d_action)
        expert_r = discriminator(expert_d_state, expert_d_action)
        gen_r_noise = gen_r.mean(dim=0)
        expert_r_noise = expert_r.mean(dim=0)
        gen_r = discriminator(gen_d_state, gen_d_action, noise=False)
        expert_r = discriminator(expert_d_state, expert_d_action, noise=False)
        if write_scalar:
            writer.add_scalar('gen_r_accurate/onehot', gen_r.mean(dim=0)[0], ep)
            writer.add_scalar('gen_r_accurate/multihot', gen_r.mean(dim=0)[1], ep)
            writer.add_scalar('gen_r_accurate/continuous', gen_r.mean(dim=0)[2], ep)
            writer.add_scalar('gen_r_accurate/next_state', gen_r.mean(dim=0)[3], ep)
            writer.add_scalar('expert_r_accurate/onehot', expert_r.mean(dim=0)[0], ep)
            writer.add_scalar('expert_r_accurate/multihot', expert_r.mean(dim=0)[1], ep)
            writer.add_scalar('expert_r_accurate/continuous', expert_r.mean(dim=0)[2], ep)
            writer.add_scalar('expert_r_accurate/next_state', expert_r.mean(dim=0)[3], ep)
            writer.add_scalar('gen_r_with_noise/onehot', gen_r_noise[0], ep)
            writer.add_scalar('gen_r_with_noise/multihot', gen_r_noise[1], ep)
            writer.add_scalar('gen_r_with_noise/continuous', gen_r_noise[2], ep)
            writer.add_scalar('gen_r_with_noise/next_state', gen_r_noise[3], ep)
            writer.add_scalar('expert_r_with_noise/onehot', expert_r_noise[0], ep)
            writer.add_scalar('expert_r_with_noise/multihot', expert_r_noise[1], ep)
            writer.add_scalar('expert_r_with_noise/continuous', expert_r_noise[2], ep)
            writer.add_scalar('expert_r_with_noise/next_state', expert_r_noise[3], ep)
            writer.add_scalar('total/gen_r_accurate', gen_r.mean(), ep)
            writer.add_scalar('total/expert_r_accurate', expert_r.mean(), ep)
            writer.add_scalar('total/gen_r_with_noise', gen_r_noise.mean(), ep)
            writer.add_scalar('total/expert_r_with_noise', expert_r_noise.mean(), ep)
        print('#' * 5 + 'training episode:{}'.format(ep) + '#' * 5)
        print('gen_r_noise:', gen_r_noise)
        print('expert_r_noise:', expert_r_noise)
        print('gen_r:', gen_r.mean(dim=0))
        print('expert_r:', expert_r.mean(dim=0))
        print('d_loss', d_loss.item())
        # save models
        if model_name is not None:
            torch.save(discriminator, './model_pkl/multi_policy/D_' + model_name + '.pkl')
            torch.save(policy, './model_pkl/multi_policy/P_' + model_name + '.pkl')
        memory.clear_memory()
        for i, batch in enumerate(tqdm(test_loader)):
            inp, gt, gt_flag = process_valBatch(batch)

            inp = Variable(inp).float().to(device)

            with autocast(enabled=args.amp):
                out = model(inp)
            out = out.type(inp.dtype) 

            for b in range(len(batch['filename'])):
                metrics = saver.CalcNSave(out[b,...].detach().cpu().squeeze(), inp[b,...].detach().cpu().squeeze(), gt[b,...].squeeze().float() if gt_flag[b] else None, batch['filename'][b].split(".")[0])

                if metrics is not None:
                    metrics['file'] = batch['filename'][b]
                    test_metrics.append(metrics)

                    ssim = round(metrics['SSIMOut'],4)
                    test_ssim.append(ssim)
                    runningSSIM.append(ssim)
                    logging.info('[%d/%d] Test SSIM: %.4f' % (i, len(test_loader), ssim))
                    #For tensorboard
                    if i % args.log_freq == 0:
                        niter = len(test_loader)+i
                        tb_writer.add_scalar('Test/SSIM', median(runningSSIM), niter)
                        runningSSIM = []
    
    if len(test_metrics) > 0:
        print("Avg SSIM: "+str(median(test_ssim)))
        df = pd.DataFrame.from_dict(test_metrics)
        df.to_csv(os.path.join(args.save_path, 'Results.csv'), index=False)
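
# ---------------------------------------------------------------------------
# Minimal sketch of the discriminator update used in the GAIL-style main()
# above: generated state-action pairs are pushed towards label 0 and expert
# pairs towards label 1 with a BCE loss. The toy discriminator, dimensions and
# batch size here are illustrative assumptions only.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


def discriminator_update_demo(state_dim=4, action_dim=2, batch_size=64):
    disc = nn.Sequential(nn.Linear(state_dim + action_dim, 32), nn.ReLU(),
                         nn.Linear(32, 1), nn.Sigmoid())
    optimizer = torch.optim.Adam(disc.parameters(), lr=1e-3)
    criterion = nn.BCELoss()

    gen = torch.randn(batch_size, state_dim + action_dim)     # policy rollouts
    expert = torch.randn(batch_size, state_dim + action_dim)  # expert dataset

    gen_r, expert_r = disc(gen), disc(expert)
    d_loss = criterion(gen_r, torch.zeros_like(gen_r)) + \
             criterion(expert_r, torch.ones_like(expert_r))
    optimizer.zero_grad()
    d_loss.backward()
    optimizer.step()
    return d_loss.item()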
예제 #8
0
class Logger:
    """
    A general-purpose logger.

    Makes it easy to save diagnostics, hyper-parameter configurations, the
    state of a training run, and the trained model.
    """
    def __init__(
            self,
            log_dir,
            output_fname='progress.csv',
            debug: bool = False,
            exp_name=None,
            level: int = 1,  # verbosity level
            use_tensor_board=True,
            verbose=True):
        """
        Initialize a Logger.

        Args:
            log_dir (string): A directory for saving results to. It is
                created if it does not already exist.

            output_fname (string): Name for the comma-separated-value file
                containing metrics logged throughout a training run.
                Defaults to ``progress.csv``.

            exp_name (string): Experiment name. If you run multiple training
                runs and give them all the same ``exp_name``, the plotter
                will know to group them. (Use case: if you run the same
                hyperparameter configuration with multiple random seeds, you
                should give them all the same ``exp_name``.)
        """
        self.log_dir = log_dir
        # keep the flag under a separate name so that it does not shadow the
        # debug() method defined below
        self.debug_mode = debug if proc_id() == 0 else False
        self.level = level
        # only the MPI root process is allowed to print information to console
        self.verbose = verbose if proc_id() == 0 else False

        if proc_id() == 0:
            os.makedirs(self.log_dir, exist_ok=True)
            self.output_file = open(osp.join(self.log_dir, output_fname), 'w')
            atexit.register(self.output_file.close)
            print(
                colorize(f"Logging data to {self.output_file.name}",
                         'cyan',
                         bold=True))
        else:
            self.output_file = None

        self.epoch = 0
        self.first_row = True
        self.log_headers = []
        self.log_current_row = {}
        self.exp_name = exp_name
        self.torch_saver_elements = None

        # Setup tensor board logging if enabled and MPI root process
        self.summary_writer = SummaryWriter(os.path.join(self.log_dir, 'tb')) \
            if use_tensor_board and proc_id() == 0 else None

    def close(self):
        """Close opened output files immediately after training in order to
        avoid number of open files overflow. Avoids the following error:
        OSError: [Errno 24] Too many open files
        """
        if proc_id() == 0:
            self.output_file.close()

    def debug(self, msg, color='yellow'):
        """Print a colorized message to stdout."""
        if self.debug_mode:
            print(colorize(msg, color, bold=False))

    def log(self, msg, color='green'):
        """Print a colorized message to stdout."""
        if self.verbose and self.level > 0:
            print(colorize(msg, color, bold=False))

    def log_tabular(self, key, val):
        """
        Log a value of some diagnostic.

        Call this only once for each diagnostic quantity, each iteration.
        After using ``log_tabular`` to store values for each diagnostic,
        make sure to call ``dump_tabular`` to write them out to file and
        stdout (otherwise they will not get saved anywhere).
        """
        if self.first_row:
            self.log_headers.append(key)
        else:
            assert key in self.log_headers, "Trying to introduce a new key %s that you didn't include in the first iteration" % key
        assert key not in self.log_current_row, "You already set %s this iteration. Maybe you forgot to call dump_tabular()" % key
        self.log_current_row[key] = val

    def save_config(self, config):
        """
        Log an experiment configuration.

        Call this once at the top of your experiment, passing in all important
        config vars as a dict. This will serialize the config to JSON, while
        handling anything which can't be serialized in a graceful way (writing
        as informative a string as possible).

        Example use:

        .. code-block:: python

            logger = EpochLogger(**logger_kwargs)
            logger.save_config(locals())
        """
        if proc_id() == 0:  # only root process logs configurations
            config_json = convert_json(config)
            if self.exp_name is not None:
                config_json['exp_name'] = self.exp_name

            output = json.dumps(config_json,
                                separators=(',', ':\t'),
                                indent=4,
                                sort_keys=True)
            if self.verbose and self.level > 0:
                print(colorize('Run with config:', color='yellow', bold=True))
                print(output)
            with open(osp.join(self.log_dir, "config.json"), 'w') as out:
                out.write(output)

    def save_state(self, state_dict, itr=None):
        """
        Saves the state of an experiment.

        To be clear: this is about saving *state*, not logging diagnostics.
        All diagnostic logging is separate from this function. This function
        will save whatever is in ``state_dict``---usually just a copy of the
        environment---and the most recent parameters for the model you
        previously set up saving for with ``setup_torch_saver``.

        Call with any frequency you prefer. If you only want to maintain a
        single state and overwrite it at each call with the most recent
        version, leave ``itr=None``. If you want to keep all of the states you
        save, provide unique (increasing) values for 'itr'.

        Args:
            state_dict (dict): Dictionary containing essential elements to
                describe the current state of training.

            itr: An int, or None. Current iteration of training.
        """
        if proc_id() == 0:
            fname = 'state.pkl' if itr is None else 'state%d.pkl' % itr
            try:
                joblib.dump(state_dict, osp.join(self.log_dir, fname))
            except Exception:
                self.log('Warning: could not pickle state_dict.', color='red')
            # only attempt a model save if setup_torch_saver() has been called
            if self.torch_saver_elements is not None:
                self.torch_save(itr)

    def setup_torch_saver(self, what_to_save):
        """
        Set up easy model saving for a single PyTorch model.

        Because PyTorch saving and loading is especially painless, this is
        very minimal; we just need references to whatever we would like to
        pickle. This is integrated into the logger because the logger
        knows where the user would like to save information about this
        training run.

        Args:
            what_to_save: Any PyTorch model or serializable object containing
                PyTorch models.
        """
        self.torch_saver_elements = what_to_save

    def torch_save(self, itr=None):
        """
        Saves the PyTorch model (or models).
        """
        if proc_id() == 0:
            self.log('Save model to disk...')
            assert self.torch_saver_elements is not None,\
                "First have to setup saving with self.setup_torch_saver"
            fpath = 'torch_save'
            fpath = osp.join(self.log_dir, fpath)
            fname = 'model' + ('%d' % itr if itr is not None else '') + '.pt'
            fname = osp.join(fpath, fname)
            os.makedirs(fpath, exist_ok=True)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # We are using a non-recommended way of saving PyTorch models,
                # by pickling whole objects (which are dependent on the exact
                # directory structure at the time of saving) as opposed to
                # just saving network weights. This works sufficiently well
                # for the purposes of Spinning Up, but you may want to do
                # something different for your personal PyTorch project.
                # We use a catch_warnings() context to avoid the warnings about
                # not being able to save the source code.
                torch.save(self.torch_saver_elements, fname)
            # also keep a plain state_dict in a separate file so the weights can
            # be reloaded without the pickled class definitions (this file name
            # is an illustrative choice)
            state_fname = osp.join(
                fpath,
                'state_dict' + ('%d' % itr if itr is not None else '') + '.pt')
            torch.save(self.torch_saver_elements.state_dict(), state_fname)
            self.log('Done.')

    def dump_tabular(self) -> None:
        """
        Write all of the diagnostics from the current iteration.

        Writes both to stdout, and to the output file.
        """
        if proc_id() == 0:
            vals = list()
            self.epoch += 1
            # Print formatted information into console
            key_lens = [len(key) for key in self.log_headers]
            max_key_len = max(15, max(key_lens))
            keystr = '%' + '%d' % max_key_len
            fmt = "| " + keystr + "s | %15s |"
            n_slashes = 22 + max_key_len
            print("-" * n_slashes) if self.verbose and self.level > 0 else None
            for key in self.log_headers:
                val = self.log_current_row.get(key, "")
                valstr = "%8.3g" % val if hasattr(val, "__float__") else val
                if self.verbose and self.level > 0:
                    print(fmt % (key, valstr))
                vals.append(val)
            if self.verbose and self.level > 0:
                print("-" * n_slashes, flush=True)

            # Write into the output file (can be any text file format, e.g. CSV)
            if self.output_file is not None:
                if self.first_row:
                    self.output_file.write(",".join(self.log_headers) + "\n")
                self.output_file.write(",".join(map(str, vals)) + "\n")
                self.output_file.flush()

            if self.summary_writer is not None:
                for k, v in zip(self.log_headers, vals):
                    self.summary_writer.add_scalar(k, v, global_step=self.epoch)
                # Flushes the event file to disk. Call this method to make sure
                # that all pending events have been written to disk.
                self.summary_writer.flush()

        # free logged information in all processes...
        self.log_current_row.clear()
        self.first_row = False
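
# ---------------------------------------------------------------------------
# Hedged usage sketch for the Logger above. It takes an already-constructed
# Logger and a torch model as arguments, so it only relies on the methods
# defined in the class; the metric values are fabricated for illustration.
# ---------------------------------------------------------------------------
def logger_usage_demo(log, model, epochs=3):
    log.save_config({'epochs': epochs})   # dump the run configuration to JSON
    log.setup_torch_saver(model)          # register the model for torch_save()
    for epoch in range(epochs):
        log.log_tabular('Epoch', epoch)
        log.log_tabular('Loss', 1.0 / (epoch + 1))
        log.dump_tabular()                # writes stdout, CSV and TensorBoard
    log.torch_save()                      # saves the registered model
    log.close()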
예제 #9
0
class SupervisedNetwork:
    def __init__(self, config):
        now = datetime.now()
        date_time = now.strftime("%m-%d-%H-%M-%S")
        self.tensorboard_writer = SummaryWriter(log_dir='runs/Supervised-' + date_time)
        
        self.base_path = config['base_path']
        self.stamp = config['stamp']
        self.meta_epochs = config['num_meta_epochs']
        self.early_stopping = config['early_stopping']
        self.stopping_threshold = config.get('stopping_threshold', 1e-3)

        if 'seq' in config['meta_model']:
            self.model = SeqSupervisedNetwork(config)

        logger.info('Supervised network instantiated')

    def training(self, train_dataloader, val_dataloader, tags):
        best_loss = float('inf')
        best_f1 = 0
        patience = 0
        model_path = os.path.join(self.base_path, 'saved_models', 'SupervisedLearner-{}.h5'.format(self.stamp))
        classifier_path = os.path.join(self.base_path, 'saved_models', 'SupervisedClassifier-{}.h5'.format(self.stamp))
        logger.info('Model name: SupervisedLearner-{}.h5'.format(self.stamp))
        for epoch in range(self.meta_epochs):
            logger.info('Starting epoch {}/{}'.format(epoch + 1, self.meta_epochs))
            avg_loss, avg_accuracy, avg_precision, avg_recall, avg_f1 = self.model(train_dataloader, 
                                                                                   tags=tags, 
                                                                                   writer=self.tensorboard_writer)

            logger.info('Train epoch {}: Avg loss = {:.5f}, avg accuracy = {:.5f}, avg precision = {:.5f}, '
                        'avg recall = {:.5f}, avg F1 score = {:.5f}'.format(epoch + 1, avg_loss, avg_accuracy,
                                                                            avg_precision, avg_recall, avg_f1))

            self.tensorboard_writer.add_scalar('Loss/train', avg_loss, global_step=epoch + 1)
            self.tensorboard_writer.add_scalar('F1/train', avg_f1, global_step=epoch + 1)

            avg_loss, avg_accuracy, avg_precision, avg_recall, avg_f1 = self.model(val_dataloader, tags=tags, testing=True)

            logger.info('Val epoch {}: Avg loss = {:.5f}, avg accuracy = {:.5f}, avg precision = {:.5f}, '
                        'avg recall = {:.5f}, avg F1 score = {:.5f}'.format(epoch + 1, avg_loss, avg_accuracy,
                                                                            avg_precision, avg_recall, avg_f1))

            self.tensorboard_writer.add_scalar('Loss/val', avg_loss, global_step=epoch + 1)
            self.tensorboard_writer.add_scalar('F1/val', avg_f1, global_step=epoch + 1)

            if avg_f1 > best_f1 + self.stopping_threshold:
                patience = 0
                best_loss = avg_loss
                best_f1 = avg_f1
                logger.info('Saving the model since the F1 improved')
                torch.save(self.model.learner.state_dict(), model_path)
                torch.save(self.model.classifier.state_dict(), classifier_path)
                logger.info('')
            else:
                patience += 1
                logger.info('F1 did not improve')
                logger.info('')
                if patience == self.early_stopping:
                    break

            # Log params and grads into tensorboard
            for name, param in self.model.named_parameters():
                if param.requires_grad and param.grad is not None:
                    self.tensorboard_writer.add_histogram('Params/' + name, param.data.view(-1),
                                                     global_step=epoch + 1)
                    self.tensorboard_writer.add_histogram('Grads/' + name, param.grad.data.view(-1),
                                                     global_step=epoch + 1)

        self.model.learner.load_state_dict(torch.load(model_path))
        self.model.classifier.load_state_dict(torch.load(classifier_path))
        return best_f1

    def testing(self, test_dataloader, tags):
        logger.info('---------- Supervised testing starts here ----------')
        
        _, accuracy, precision, recall, f1_score = self.model(test_dataloader, tags=tags, testing=True)

        logger.info('Avg meta-testing metrics: Accuracy = {:.5f}, precision = {:.5f}, recall = {:.5f}, '
                    'F1 score = {:.5f}'.format(accuracy,
                                               precision,
                                               recall,
                                               f1_score))
        return f1_score
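
# ---------------------------------------------------------------------------
# Minimal sketch of the early-stopping rule used in SupervisedNetwork.training
# above: keep the best validation F1, reset patience when it improves by more
# than a small threshold, and stop after `early_stopping` epochs without
# improvement. The F1 values passed in are fabricated for illustration.
# ---------------------------------------------------------------------------
def early_stopping_demo(val_f1_per_epoch, early_stopping=3, threshold=1e-3):
    best_f1, patience = 0.0, 0
    for f1 in val_f1_per_epoch:
        if f1 > best_f1 + threshold:
            best_f1, patience = f1, 0   # checkpoint would be saved here
        else:
            patience += 1
            if patience == early_stopping:
                break
    return best_f1


# e.g. early_stopping_demo([0.60, 0.62, 0.61, 0.61, 0.61]) returns 0.62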
예제 #10
0
class TrainManager:
    """ Manages training loop, validations, learning rate scheduling
    and early stopping."""
    def __init__(self, model: SignModel, config: dict) -> None:
        """
        Creates a new TrainManager for a model, specified as in configuration.

        :param model: torch module defining the model
        :param config: dictionary containing the training configurations
        """
        train_config = config["training"]

        # files for logging and storing
        self.model_dir = make_model_dir(train_config["model_dir"],
                                        overwrite=train_config.get(
                                            "overwrite", False))
        self.logger = make_logger(model_dir=self.model_dir)
        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = "{}/validations.txt".format(self.model_dir)
        self.tb_writer = SummaryWriter(log_dir=self.model_dir +
                                       "/tensorboard/")

        # input
        self.feature_size = (
            sum(config["data"]["feature_size"])
            if isinstance(config["data"]["feature_size"], list)
            else config["data"]["feature_size"])
        self.feature_size_features = config["data"]["feature_size_cnn"]

        self.dataset_version = config["data"].get("version",
                                                  "phoenix_2014_trans")

        # model
        self.model = model
        self.txt_pad_index = self.model.txt_pad_index
        self.txt_bos_index = self.model.txt_bos_index
        self._log_parameters_list()
        # Check if we are doing only recognition or only translation or both
        self.do_recognition = (config["training"].get(
            "recognition_loss_weight", 1.0) > 0.0)
        self.do_translation = (config["training"].get(
            "translation_loss_weight", 1.0) > 0.0)

        # Get Recognition and Translation specific parameters
        if self.do_recognition:
            self._get_recognition_params(train_config=train_config)
        if self.do_translation:
            self._get_translation_params(train_config=train_config)

        # optimization
        self.last_best_lr = train_config.get("learning_rate", -1)
        self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)
        self.clip_grad_fun = build_gradient_clipper(config=train_config)
        self.optimizer = build_optimizer(config=train_config,
                                         parameters=model.parameters())
        self.batch_multiplier = train_config.get("batch_multiplier", 1)

        # validation & early stopping
        self.validation_freq = train_config.get("validation_freq", 100)
        self.num_valid_log = train_config.get("num_valid_log", 5)
        self.ckpt_queue = queue.Queue(
            maxsize=train_config.get("keep_last_ckpts", 5))
        self.eval_metric = train_config.get("eval_metric", "bleu")
        if self.eval_metric not in ["bleu", "chrf", "wer", "rouge"]:
            raise ValueError("Invalid setting for 'eval_metric': {}".format(
                self.eval_metric))
        self.early_stopping_metric = train_config.get("early_stopping_metric",
                                                      "eval_metric")

        # if we schedule after BLEU/chrf, we want to maximize it, else minimize
        # early_stopping_metric decides on how to find the early stopping point:
        # ckpts are written when there's a new high/low score for this metric
        if self.early_stopping_metric in [
                "ppl",
                "translation_loss",
                "recognition_loss",
        ]:
            self.minimize_metric = True
        elif self.early_stopping_metric == "eval_metric":
            if self.eval_metric in ["bleu", "chrf", "rouge"]:
                assert self.do_translation
                self.minimize_metric = False
            else:  # eval metric that has to get minimized (not yet implemented)
                self.minimize_metric = True
        else:
            raise ValueError(
                "Invalid setting for 'early_stopping_metric': {}".format(
                    self.early_stopping_metric))

        # data_augmentation parameters
        self.frame_subsampling_ratio = config["data"].get(
            "frame_subsampling_ratio", None)
        self.random_frame_subsampling = config["data"].get(
            "random_frame_subsampling", None)
        self.random_frame_masking_ratio = config["data"].get(
            "random_frame_masking_ratio", None)

        # learning rate scheduling
        self.scheduler, self.scheduler_step_at = build_scheduler(
            config=train_config,
            scheduler_mode="min" if self.minimize_metric else "max",
            optimizer=self.optimizer,
            hidden_size=config["model"]["encoder"]["hidden_size"],
        )

        # data & batch handling
        self.level = config["data"]["level"]
        if self.level not in ["word", "bpe", "char"]:
            raise ValueError("Invalid segmentation level': {}".format(
                self.level))

        self.shuffle = train_config.get("shuffle", True)
        self.epochs = train_config["epochs"]
        self.batch_size = train_config["batch_size"]
        self.batch_type = train_config.get("batch_type", "sentence")
        self.eval_batch_size = train_config.get("eval_batch_size",
                                                self.batch_size)
        self.eval_batch_type = train_config.get("eval_batch_type",
                                                self.batch_type)

        self.use_cuda = train_config["use_cuda"]
        if self.use_cuda:
            self.model.cuda()
            if self.do_translation:
                self.translation_loss_function.cuda()
            if self.do_recognition:
                self.recognition_loss_function.cuda()

        # initialize training statistics
        self.steps = 0
        # stop training if this flag is True by reaching learning rate minimum
        self.stop = False
        self.total_txt_tokens = 0
        self.total_gls_tokens = 0
        self.best_ckpt_iteration = 0
        # initial values for best scores
        self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf
        self.best_all_ckpt_scores = {}
        # comparison function for scores
        self.is_best = (
            lambda score: score < self.best_ckpt_score
            if self.minimize_metric else score > self.best_ckpt_score)

        # model parameters
        if "load_model" in train_config.keys():
            model_load_path = train_config["load_model"]
            self.logger.info("Loading model from %s", model_load_path)
            reset_best_ckpt = train_config.get("reset_best_ckpt", False)
            reset_scheduler = train_config.get("reset_scheduler", False)
            reset_optimizer = train_config.get("reset_optimizer", False)
            self.init_from_checkpoint(
                model_load_path,
                reset_best_ckpt=reset_best_ckpt,
                reset_scheduler=reset_scheduler,
                reset_optimizer=reset_optimizer,
            )

    def _get_recognition_params(self, train_config) -> None:
        # NOTE (Cihan): The blank label is the silence index in the gloss vocabulary.
        #   There is an assertion in the GlossVocabulary class's __init__ that
        #   enforces this. It is necessary for TensorFlow-style CTC decoding,
        #   where the blank index is hardcoded; here it is hardcoded as 0.
        self.gls_silence_token = self.model.gls_vocab.stoi[SIL_TOKEN]
        assert self.gls_silence_token == 0

        self.recognition_loss_function = torch.nn.CTCLoss(
            blank=self.gls_silence_token, zero_infinity=True)
        self.recognition_loss_weight = train_config.get(
            "recognition_loss_weight", 1.0)
        self.eval_recognition_beam_size = train_config.get(
            "eval_recognition_beam_size", 1)

    def _get_translation_params(self, train_config) -> None:
        self.label_smoothing = train_config.get("label_smoothing", 0.0)
        self.translation_loss_function = XentLoss(
            pad_index=self.txt_pad_index, smoothing=self.label_smoothing)
        self.translation_normalization_mode = train_config.get(
            "translation_normalization", "batch")
        if self.translation_normalization_mode not in ["batch", "tokens"]:
            raise ValueError("Invalid normalization {}.".format(
                self.translation_normalization_mode))
        self.translation_loss_weight = train_config.get(
            "translation_loss_weight", 1.0)
        self.eval_translation_beam_size = train_config.get(
            "eval_translation_beam_size", 1)
        self.eval_translation_beam_alpha = train_config.get(
            "eval_translation_beam_alpha", -1)
        self.translation_max_output_length = train_config.get(
            "translation_max_output_length", None)

    def _save_checkpoint(self) -> None:
        """
        Save the model's current parameters and the training state to a
        checkpoint.

        The training state contains the total number of training steps,
        the total number of training tokens,
        the best checkpoint score and iteration so far,
        and optimizer and scheduler states.

        """
        model_path = "{}/{}.ckpt".format(self.model_dir, self.steps)
        state = {
            "steps":
            self.steps,
            "total_txt_tokens":
            self.total_txt_tokens if self.do_translation else 0,
            "total_gls_tokens":
            self.total_gls_tokens if self.do_recognition else 0,
            "best_ckpt_score":
            self.best_ckpt_score,
            "best_all_ckpt_scores":
            self.best_all_ckpt_scores,
            "best_ckpt_iteration":
            self.best_ckpt_iteration,
            "model_state":
            self.model.state_dict(),
            "optimizer_state":
            self.optimizer.state_dict(),
            "scheduler_state":
            self.scheduler.state_dict()
            if self.scheduler is not None else None,
        }
        torch.save(state, model_path)
        if self.ckpt_queue.full():
            to_delete = self.ckpt_queue.get()  # delete oldest ckpt
            try:
                os.remove(to_delete)
            except FileNotFoundError:
                self.logger.warning(
                    "Wanted to delete old checkpoint %s but "
                    "file does not exist.",
                    to_delete,
                )

        self.ckpt_queue.put(model_path)

        # create/modify symbolic link for best checkpoint
        symlink_update("{}.ckpt".format(self.steps),
                       "{}/best.ckpt".format(self.model_dir))

    def init_from_checkpoint(
        self,
        path: str,
        reset_best_ckpt: bool = False,
        reset_scheduler: bool = False,
        reset_optimizer: bool = False,
    ) -> None:
        """
        Initialize the trainer from a given checkpoint file.

        This checkpoint file contains not only model parameters, but also
        scheduler and optimizer states, see `self._save_checkpoint`.

        :param path: path to checkpoint
        :param reset_best_ckpt: reset tracking of the best checkpoint,
                                use for domain adaptation with a new dev
                                set or when using a new metric for fine-tuning.
        :param reset_scheduler: reset the learning rate scheduler, and do not
                                use the one stored in the checkpoint.
        :param reset_optimizer: reset the optimizer, and do not use the one
                                stored in the checkpoint.
        """
        model_checkpoint = load_checkpoint(path=path, use_cuda=self.use_cuda)

        # restore model and optimizer parameters
        self.model.load_state_dict(model_checkpoint["model_state"])

        if not reset_optimizer:
            self.optimizer.load_state_dict(model_checkpoint["optimizer_state"])
        else:
            self.logger.info("Reset optimizer.")

        if not reset_scheduler:
            if (model_checkpoint["scheduler_state"] is not None
                    and self.scheduler is not None):
                self.scheduler.load_state_dict(
                    model_checkpoint["scheduler_state"])
        else:
            self.logger.info("Reset scheduler.")

        # restore counts
        self.steps = model_checkpoint["steps"]
        self.total_txt_tokens = model_checkpoint["total_txt_tokens"]
        self.total_gls_tokens = model_checkpoint["total_gls_tokens"]

        if not reset_best_ckpt:
            self.best_ckpt_score = model_checkpoint["best_ckpt_score"]
            self.best_all_ckpt_scores = model_checkpoint[
                "best_all_ckpt_scores"]
            self.best_ckpt_iteration = model_checkpoint["best_ckpt_iteration"]
        else:
            self.logger.info("Reset tracking of the best checkpoint.")

        # move parameters to cuda
        if self.use_cuda:
            self.model.cuda()
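
    # Note: resuming training amounts to setting train_config["load_model"] to a
    # previous "<steps>.ckpt" written by _save_checkpoint(); the reset_best_ckpt,
    # reset_scheduler and reset_optimizer options above then control how much of
    # the saved training state is carried over.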

    def train_and_validate(self, train_data: Dataset,
                           valid_data: Dataset) -> None:
        """
        Train the model and validate it from time to time on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        """
        train_iter = make_data_iter(
            train_data,
            batch_size=self.batch_size,
            batch_type=self.batch_type,
            train=True,
            shuffle=self.shuffle,
        )
        epoch_no = None
        for epoch_no in range(self.epochs):
            self.logger.info("EPOCH %d", epoch_no + 1)

            if self.scheduler is not None and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no)

            self.model.train()
            start = time.time()
            total_valid_duration = 0
            count = self.batch_multiplier - 1

            if self.do_recognition:
                processed_gls_tokens = self.total_gls_tokens
                epoch_recognition_loss = 0
            if self.do_translation:
                processed_txt_tokens = self.total_txt_tokens
                epoch_translation_loss = 0

            for batch in iter(train_iter):
                # reactivate training
                # create a Batch object from torchtext batch
                batch = Batch(
                    is_train=True,
                    torch_batch=batch,
                    txt_pad_index=self.txt_pad_index,
                    sgn_dim=self.feature_size,
                    features_dim=self.feature_size_features,
                    use_cuda=self.use_cuda,
                    frame_subsampling_ratio=self.frame_subsampling_ratio,
                    random_frame_subsampling=self.random_frame_subsampling,
                    random_frame_masking_ratio=self.random_frame_masking_ratio,
                )

                # only update every batch_multiplier batches, i.e. accumulate
                # gradients over several batches to emulate a larger batch size;
                # see https://medium.com/@davidlmorton/increasing-mini-batch-size-without-increasing-memory-6794e10db672
                update = count == 0

                recognition_loss, translation_loss = self._train_batch(
                    batch, update=update)

                if self.do_recognition:
                    self.tb_writer.add_scalar("train/train_recognition_loss",
                                              recognition_loss, self.steps)
                    epoch_recognition_loss += recognition_loss.detach().cpu(
                    ).numpy()

                if self.do_translation:
                    self.tb_writer.add_scalar("train/train_translation_loss",
                                              translation_loss, self.steps)
                    epoch_translation_loss += translation_loss.detach().cpu(
                    ).numpy()

                count = self.batch_multiplier if update else count
                count -= 1

                if (self.scheduler is not None
                        and self.scheduler_step_at == "step" and update):
                    self.scheduler.step()

                # log learning progress
                if self.steps % self.logging_freq == 0 and update:
                    elapsed = time.time() - start - total_valid_duration

                    log_out = "[Epoch: {:03d} Step: {:08d}] ".format(
                        epoch_no + 1,
                        self.steps,
                    )

                    if self.do_recognition:
                        elapsed_gls_tokens = (self.total_gls_tokens -
                                              processed_gls_tokens)
                        processed_gls_tokens = self.total_gls_tokens
                        log_out += "Batch Recognition Loss: {:10.6f} => ".format(
                            recognition_loss)
                        log_out += "Gls Tokens per Sec: {:8.0f} || ".format(
                            elapsed_gls_tokens / elapsed)
                    if self.do_translation:
                        elapsed_txt_tokens = (self.total_txt_tokens -
                                              processed_txt_tokens)
                        processed_txt_tokens = self.total_txt_tokens
                        log_out += "Batch Translation Loss: {:10.6f} => ".format(
                            translation_loss)
                        log_out += "Txt Tokens per Sec: {:8.0f} || ".format(
                            elapsed_txt_tokens / elapsed)
                    log_out += "Lr: {:.6f}".format(
                        self.optimizer.param_groups[0]["lr"])
                    self.logger.info(log_out)
                    start = time.time()
                    total_valid_duration = 0

                # validate on the entire dev set
                if self.steps % self.validation_freq == 0 and update:
                    valid_start_time = time.time()
                    # TODO (Cihan): There must be a better way of passing
                    #   these recognition only and translation only parameters!
                    #   Maybe have a NamedTuple with optional fields?
                    #   Hmm... Future Cihan's problem.
                    val_res = validate_on_data(
                        model=self.model,
                        data=valid_data,
                        batch_size=self.eval_batch_size,
                        use_cuda=self.use_cuda,
                        batch_type=self.eval_batch_type,
                        dataset_version=self.dataset_version,
                        sgn_dim=self.feature_size,
                        features_dim=self.feature_size_features,
                        txt_pad_index=self.txt_pad_index,
                        # Recognition Parameters
                        do_recognition=self.do_recognition,
                        recognition_loss_function=self.
                        recognition_loss_function
                        if self.do_recognition else None,
                        recognition_loss_weight=self.recognition_loss_weight
                        if self.do_recognition else None,
                        recognition_beam_size=self.eval_recognition_beam_size
                        if self.do_recognition else None,
                        # Translation Parameters
                        do_translation=self.do_translation,
                        translation_loss_function=self.
                        translation_loss_function
                        if self.do_translation else None,
                        translation_max_output_length=self.
                        translation_max_output_length
                        if self.do_translation else None,
                        level=self.level if self.do_translation else None,
                        translation_loss_weight=self.translation_loss_weight
                        if self.do_translation else None,
                        translation_beam_size=self.eval_translation_beam_size
                        if self.do_translation else None,
                        translation_beam_alpha=self.eval_translation_beam_alpha
                        if self.do_translation else None,
                        frame_subsampling_ratio=self.frame_subsampling_ratio,
                    )
                    self.model.train()

                    if self.do_recognition:
                        # Log Losses and ppl
                        self.tb_writer.add_scalar(
                            "valid/valid_recognition_loss",
                            val_res["valid_recognition_loss"],
                            self.steps,
                        )
                        self.tb_writer.add_scalar(
                            "valid/wer", val_res["valid_scores"]["wer"],
                            self.steps)
                        self.tb_writer.add_scalars(
                            "valid/wer_scores",
                            val_res["valid_scores"]["wer_scores"],
                            self.steps,
                        )

                    if self.do_translation:
                        self.tb_writer.add_scalar(
                            "valid/valid_translation_loss",
                            val_res["valid_translation_loss"],
                            self.steps,
                        )
                        self.tb_writer.add_scalar("valid/valid_ppl",
                                                  val_res["valid_ppl"],
                                                  self.steps)

                        # Log Scores
                        self.tb_writer.add_scalar(
                            "valid/chrf", val_res["valid_scores"]["chrf"],
                            self.steps)
                        self.tb_writer.add_scalar(
                            "valid/rouge", val_res["valid_scores"]["rouge"],
                            self.steps)
                        self.tb_writer.add_scalar(
                            "valid/bleu", val_res["valid_scores"]["bleu"],
                            self.steps)
                        self.tb_writer.add_scalars(
                            "valid/bleu_scores",
                            val_res["valid_scores"]["bleu_scores"],
                            self.steps,
                        )

                    if self.early_stopping_metric == "recognition_loss":
                        assert self.do_recognition
                        ckpt_score = val_res["valid_recognition_loss"]
                    elif self.early_stopping_metric == "translation_loss":
                        assert self.do_translation
                        ckpt_score = val_res["valid_translation_loss"]
                    elif self.early_stopping_metric in ["ppl", "perplexity"]:
                        assert self.do_translation
                        ckpt_score = val_res["valid_ppl"]
                    else:
                        ckpt_score = val_res["valid_scores"][self.eval_metric]

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_all_ckpt_scores = val_res["valid_scores"]
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            "Hooray! New best validation result [%s]!",
                            self.early_stopping_metric,
                        )
                        if self.ckpt_queue.maxsize > 0:
                            self.logger.info("Saving new checkpoint.")
                            new_best = True
                            self._save_checkpoint()

                    if (self.scheduler is not None
                            and self.scheduler_step_at == "validation"):
                        prev_lr = self.scheduler.optimizer.param_groups[0]["lr"]
                        self.scheduler.step(ckpt_score)
                        now_lr = self.scheduler.optimizer.param_groups[0]["lr"]

                        if prev_lr != now_lr:
                            if self.last_best_lr != prev_lr:
                                self.stop = True

                    # append to validation report
                    self._add_report(
                        valid_scores=val_res["valid_scores"],
                        valid_recognition_loss=val_res["valid_recognition_loss"]
                        if self.do_recognition else None,
                        valid_translation_loss=val_res["valid_translation_loss"]
                        if self.do_translation else None,
                        valid_ppl=val_res["valid_ppl"]
                        if self.do_translation else None,
                        eval_metric=self.eval_metric,
                        new_best=new_best,
                    )
                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
                    self.logger.info(
                        "Validation result at epoch %3d, step %8d: duration: %.4fs\n\t"
                        "Recognition Beam Size: %d\t"
                        "Translation Beam Size: %d\t"
                        "Translation Beam Alpha: %d\n\t"
                        "Recognition Loss: %4.5f\t"
                        "Translation Loss: %4.5f\t"
                        "PPL: %4.5f\n\t"
                        "Eval Metric: %s\n\t"
                        "WER %3.2f\t(DEL: %3.2f,\tINS: %3.2f,\tSUB: %3.2f)\n\t"
                        "BLEU-4 %.2f\t(BLEU-1: %.2f,\tBLEU-2: %.2f,\tBLEU-3: %.2f,\tBLEU-4: %.2f)\n\t"
                        "CHRF %.2f\t"
                        "ROUGE %.2f",
                        epoch_no + 1,
                        self.steps,
                        valid_duration,
                        self.eval_recognition_beam_size
                        if self.do_recognition else -1,
                        self.eval_translation_beam_size
                        if self.do_translation else -1,
                        self.eval_translation_beam_alpha
                        if self.do_translation else -1,
                        val_res["valid_recognition_loss"]
                        if self.do_recognition else -1,
                        val_res["valid_translation_loss"]
                        if self.do_translation else -1,
                        val_res["valid_ppl"] if self.do_translation else -1,
                        self.eval_metric.upper(),
                        # WER
                        val_res["valid_scores"]["wer"]
                        if self.do_recognition else -1,
                        val_res["valid_scores"]["wer_scores"]["del_rate"]
                        if self.do_recognition else -1,
                        val_res["valid_scores"]["wer_scores"]["ins_rate"]
                        if self.do_recognition else -1,
                        val_res["valid_scores"]["wer_scores"]["sub_rate"]
                        if self.do_recognition else -1,
                        # BLEU
                        val_res["valid_scores"]["bleu"]
                        if self.do_translation else -1,
                        val_res["valid_scores"]["bleu_scores"]["bleu1"]
                        if self.do_translation else -1,
                        val_res["valid_scores"]["bleu_scores"]["bleu2"]
                        if self.do_translation else -1,
                        val_res["valid_scores"]["bleu_scores"]["bleu3"]
                        if self.do_translation else -1,
                        val_res["valid_scores"]["bleu_scores"]["bleu4"]
                        if self.do_translation else -1,
                        # Other
                        val_res["valid_scores"]["chrf"]
                        if self.do_translation else -1,
                        val_res["valid_scores"]["rouge"]
                        if self.do_translation else -1,
                    )

                    self._log_examples(
                        sequences=[s for s in valid_data.sequence],
                        gls_references=val_res["gls_ref"]
                        if self.do_recognition else None,
                        gls_hypotheses=val_res["gls_hyp"]
                        if self.do_recognition else None,
                        txt_references=val_res["txt_ref"]
                        if self.do_translation else None,
                        txt_hypotheses=val_res["txt_hyp"]
                        if self.do_translation else None,
                    )

                    valid_seq = [s for s in valid_data.sequence]
                    # store validation set outputs and references
                    if self.do_recognition:
                        self._store_outputs("dev.hyp.gls", valid_seq,
                                            val_res["gls_hyp"], "gls")
                        self._store_outputs("references.dev.gls", valid_seq,
                                            val_res["gls_ref"])

                    if self.do_translation:
                        self._store_outputs("dev.hyp.txt", valid_seq,
                                            val_res["txt_hyp"], "txt")
                        self._store_outputs("references.dev.txt", valid_seq,
                                            val_res["txt_ref"])

                if self.stop:
                    break
            if self.stop:
                if (self.scheduler is not None
                        and self.scheduler_step_at == "validation"
                        and self.last_best_lr != prev_lr):
                    self.logger.info(
                        "Training ended since there were no improvements in"
                        "the last learning rate step: %f",
                        prev_lr,
                    )
                else:
                    self.logger.info(
                        "Training ended since minimum lr %f was reached.",
                        self.learning_rate_min,
                    )
                break

            self.logger.info(
                "Epoch %3d: Total Training Recognition Loss %.2f "
                " Total Training Translation Loss %.2f ",
                epoch_no + 1,
                epoch_recognition_loss if self.do_recognition else -1,
                epoch_translation_loss if self.do_translation else -1,
            )
        else:
            self.logger.info("Training ended after %3d epochs.", epoch_no + 1)
        self.logger.info(
            "Best validation result at step %8d: %6.2f %s.",
            self.best_ckpt_iteration,
            self.best_ckpt_score,
            self.early_stopping_metric,
        )

        self.tb_writer.close()  # close Tensorboard writer

    def _train_batch(self,
                     batch: Batch,
                     update: bool = True) -> (Tensor, Tensor):
        """
        Train the model on one batch: Compute the loss, make a gradient step.

        :param batch: training batch
        :param update: if False, only store gradient. if True also make update
        :return normalized_recognition_loss: Normalized recognition loss
        :return normalized_translation_loss: Normalized translation loss
        """

        recognition_loss, translation_loss = self.model.get_loss_for_batch(
            batch=batch,
            recognition_loss_function=self.recognition_loss_function
            if self.do_recognition else None,
            translation_loss_function=self.translation_loss_function
            if self.do_translation else None,
            recognition_loss_weight=self.recognition_loss_weight
            if self.do_recognition else None,
            translation_loss_weight=self.translation_loss_weight
            if self.do_translation else None,
        )

        # normalize translation loss
        if self.do_translation:
            if self.translation_normalization_mode == "batch":
                txt_normalization_factor = batch.num_seqs
            elif self.translation_normalization_mode == "tokens":
                txt_normalization_factor = batch.num_txt_tokens
            else:
                raise NotImplementedError(
                    "Only normalize by 'batch' or 'tokens'")

            # division needed since loss.backward sums the gradients until updated
            normalized_translation_loss = translation_loss / (
                txt_normalization_factor * self.batch_multiplier)
        else:
            normalized_translation_loss = 0

        # TODO (Cihan): Add Gloss Token normalization (?)
        #   I think they are already being normalized by batch
        #   I need to think about if I want to normalize them by token.
        if self.do_recognition:
            normalized_recognition_loss = recognition_loss / self.batch_multiplier
        else:
            normalized_recognition_loss = 0

        total_loss = normalized_recognition_loss + normalized_translation_loss
        # compute gradients
        total_loss.backward()

        if self.clip_grad_fun is not None:
            # clip gradients (in-place)
            self.clip_grad_fun(params=self.model.parameters())

        if update:
            # make gradient step
            self.optimizer.step()
            self.optimizer.zero_grad()

            # increment step counter
            self.steps += 1

        # increment token counter
        if self.do_recognition:
            self.total_gls_tokens += batch.num_gls_tokens
        if self.do_translation:
            self.total_txt_tokens += batch.num_txt_tokens

        return normalized_recognition_loss, normalized_translation_loss

    def _add_report(
        self,
        valid_scores: Dict,
        valid_recognition_loss: float,
        valid_translation_loss: float,
        valid_ppl: float,
        eval_metric: str,
        new_best: bool = False,
    ) -> None:
        """
        Append a one-line report to validation logging file.

        :param valid_scores: Dictionary of validation scores
        :param valid_recognition_loss: validation loss (sum over whole validation set)
        :param valid_translation_loss: validation loss (sum over whole validation set)
        :param valid_ppl: validation perplexity
        :param eval_metric: evaluation metric, e.g. "bleu"
        :param new_best: whether this is a new best model
        """
        current_lr = -1
        # ignores other param groups for now
        for param_group in self.optimizer.param_groups:
            current_lr = param_group["lr"]

        if new_best:
            self.last_best_lr = current_lr

        if current_lr < self.learning_rate_min:
            self.stop = True

        with open(self.valid_report_file, "a",
                  encoding="utf-8") as opened_file:
            opened_file.write(
                "Steps: {}\t"
                "Recognition Loss: {:.5f}\t"
                "Translation Loss: {:.5f}\t"
                "PPL: {:.5f}\t"
                "Eval Metric: {}\t"
                "WER {:.2f}\t(DEL: {:.2f},\tINS: {:.2f},\tSUB: {:.2f})\t"
                "BLEU-4 {:.2f}\t(BLEU-1: {:.2f},\tBLEU-2: {:.2f},\tBLEU-3: {:.2f},\tBLEU-4: {:.2f})\t"
                "CHRF {:.2f}\t"
                "ROUGE {:.2f}\t"
                "LR: {:.8f}\t{}\n".format(
                    self.steps,
                    valid_recognition_loss if self.do_recognition else -1,
                    valid_translation_loss if self.do_translation else -1,
                    valid_ppl if self.do_translation else -1,
                    eval_metric,
                    # WER
                    valid_scores["wer"] if self.do_recognition else -1,
                    valid_scores["wer_scores"]["del_rate"]
                    if self.do_recognition else -1,
                    valid_scores["wer_scores"]["ins_rate"]
                    if self.do_recognition else -1,
                    valid_scores["wer_scores"]["sub_rate"]
                    if self.do_recognition else -1,
                    # BLEU
                    valid_scores["bleu"] if self.do_translation else -1,
                    valid_scores["bleu_scores"]["bleu1"]
                    if self.do_translation else -1,
                    valid_scores["bleu_scores"]["bleu2"]
                    if self.do_translation else -1,
                    valid_scores["bleu_scores"]["bleu3"]
                    if self.do_translation else -1,
                    valid_scores["bleu_scores"]["bleu4"]
                    if self.do_translation else -1,
                    # Other
                    valid_scores["chrf"] if self.do_translation else -1,
                    valid_scores["rouge"] if self.do_translation else -1,
                    current_lr,
                    "*" if new_best else "",
                ))

    def _log_parameters_list(self) -> None:
        """
        Write all model parameters (name, shape) to the log.
        """
        model_parameters = filter(lambda p: p.requires_grad,
                                  self.model.parameters())
        n_params = sum([np.prod(p.size()) for p in model_parameters])
        self.logger.info("Total params: %d", n_params)
        trainable_params = [
            n for (n, p) in self.model.named_parameters() if p.requires_grad
        ]
        self.logger.info("Trainable parameters: %s", sorted(trainable_params))
        assert trainable_params

    def _log_examples(
        self,
        sequences: List[str],
        gls_references: List[str],
        gls_hypotheses: List[str],
        txt_references: List[str],
        txt_hypotheses: List[str],
    ) -> None:
        """
        Log `self.num_valid_log` number of samples from valid.

        :param sequences: sign video sequence names (list of strings)
        :param txt_hypotheses: decoded txt hypotheses (list of strings)
        :param txt_references: decoded txt references (list of strings)
        :param gls_hypotheses: decoded gls hypotheses (list of strings)
        :param gls_references: decoded gls references (list of strings)
        """

        if self.do_recognition:
            assert len(gls_references) == len(gls_hypotheses)
            num_sequences = len(gls_hypotheses)
        if self.do_translation:
            assert len(txt_references) == len(txt_hypotheses)
            num_sequences = len(txt_hypotheses)

        rand_idx = np.sort(
            np.random.permutation(num_sequences)[:self.num_valid_log])
        self.logger.info("Logging Recognition and Translation Outputs")
        self.logger.info("=" * 120)
        for ri in rand_idx:
            self.logger.info("Logging Sequence: %s", sequences[ri])
            if self.do_recognition:
                gls_res = wer_single(r=gls_references[ri],
                                     h=gls_hypotheses[ri])
                self.logger.info("\tGloss Reference :\t%s",
                                 gls_res["alignment_out"]["align_ref"])
                self.logger.info("\tGloss Hypothesis:\t%s",
                                 gls_res["alignment_out"]["align_hyp"])
                self.logger.info("\tGloss Alignment :\t%s",
                                 gls_res["alignment_out"]["alignment"])
            if self.do_recognition and self.do_translation:
                self.logger.info("\t" + "-" * 116)
            if self.do_translation:
                txt_res = wer_single(r=txt_references[ri],
                                     h=txt_hypotheses[ri])
                self.logger.info("\tText Reference  :\t%s",
                                 txt_res["alignment_out"]["align_ref"])
                self.logger.info("\tText Hypothesis :\t%s",
                                 txt_res["alignment_out"]["align_hyp"])
                self.logger.info("\tText Alignment  :\t%s",
                                 txt_res["alignment_out"]["alignment"])
            self.logger.info("=" * 120)

    def _store_outputs(self,
                       tag: str,
                       sequence_ids: List[str],
                       hypotheses: List[str],
                       sub_folder=None) -> None:
        """
        Write current validation outputs to file in `self.model_dir`.

        :param hypotheses: list of strings
        """
        if sub_folder:
            out_folder = os.path.join(self.model_dir, sub_folder)
            if not os.path.exists(out_folder):
                os.makedirs(out_folder)
            current_valid_output_file = "{}/{}.{}".format(
                out_folder, self.steps, tag)
        else:
            out_folder = self.model_dir
            current_valid_output_file = "{}/{}".format(out_folder, tag)

        with open(current_valid_output_file, "w",
                  encoding="utf-8") as opened_file:
            for seq, hyp in zip(sequence_ids, hypotheses):
                opened_file.write("{}|{}\n".format(seq, hyp))
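
Note on the training loop above: the trainer only calls optimizer.step() every batch_multiplier mini-batches (the `update = count == 0` flag), and _train_batch() divides each loss by the same factor, so the accumulated gradient approximates one large batch. Below is a minimal, self-contained sketch of that pattern only; the linear model, random data, and multiplier value are placeholders, not taken from the example.

import torch
import torch.nn as nn

# Gradient-accumulation sketch: one optimizer step per `batch_multiplier`
# mini-batches; each loss is divided by the multiplier so the accumulated
# gradient approximates a single large batch.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
batch_multiplier = 4

optimizer.zero_grad()
for step in range(20):
    x, y = torch.randn(8, 10), torch.randn(8, 1)
    loss = criterion(model(x), y) / batch_multiplier
    loss.backward()                      # gradients accumulate across calls
    if (step + 1) % batch_multiplier == 0:
        optimizer.step()                 # update once per `batch_multiplier` batches
        optimizer.zero_grad()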
Example #11
0
File: dqn.py Project: sudo-michael/cleanrl
    # TRY NOT TO MODIFY: execute the game and log data.
    next_obs, reward, done, _ = env.step(action)
    episode_reward += reward

    # ALGO LOGIC: training.
    rb.put((obs, action, reward, next_obs, done))
    if global_step > args.learning_starts and global_step % args.train_frequency == 0:
        s_obs, s_actions, s_rewards, s_next_obses, s_dones = rb.sample(args.batch_size)
        with torch.no_grad():
            target_max = torch.max(target_network.forward(s_next_obses, device), dim=1)[0]
            td_target = torch.Tensor(s_rewards).to(device) + args.gamma * target_max * (1 - torch.Tensor(s_dones).to(device))
        old_val = q_network.forward(s_obs, device).gather(1, torch.LongTensor(s_actions).view(-1,1).to(device)).squeeze()
        loss = loss_fn(td_target, old_val)

        if global_step % 100 == 0:
            writer.add_scalar("losses/td_loss", loss, global_step)

        # optimize the model
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(list(q_network.parameters()), args.max_grad_norm)
        optimizer.step()

        # update the target network
        if global_step % args.target_network_frequency == 0:
            target_network.load_state_dict(q_network.state_dict())

    # TRY NOT TO MODIFY: CRUCIAL step easy to overlook 
    obs = next_obs

    if done:
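
The TD target in the snippet above follows the standard one-step Q-learning form, r + gamma * max_a' Q_target(s', a') * (1 - done), with a hard sync of the target network every target_network_frequency steps. Below is a small self-contained sketch of just that computation; the toy linear Q-networks and random batch data are placeholders and do not come from the example.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy sketch of the DQN TD target and hard target-network sync.
q_network = nn.Linear(4, 2)
target_network = nn.Linear(4, 2)
target_network.load_state_dict(q_network.state_dict())  # hard sync

gamma = 0.99
obs = torch.randn(32, 4)
actions = torch.randint(0, 2, (32,))
rewards = torch.randn(32)
next_obs = torch.randn(32, 4)
dones = torch.randint(0, 2, (32,)).float()

with torch.no_grad():
    target_max = target_network(next_obs).max(dim=1).values
    td_target = rewards + gamma * target_max * (1 - dones)

old_val = q_network(obs).gather(1, actions.view(-1, 1)).squeeze(1)
loss = F.mse_loss(old_val, td_target)
loss.backward()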
Example #12
0
def train_melgan(args):
    root = Path(args.save_path)
    load_root = Path(args.load_path) if args.load_path else None
    root.mkdir(parents=True, exist_ok=True)

    metadata_dir = root.joinpath('metadata')
    metadata_dir.mkdir(exist_ok=True)

    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(metadata_dir / "args.yml", "w") as f:
        yaml.dump(args.__dict__, f)
    with open(metadata_dir / "args.json", "w", encoding="utf8") as f:
        json.dump(args.__dict__, f, indent=4, ensure_ascii=False)

    eventdir = root / "events"
    eventdir.mkdir(exist_ok=True)
    writer = SummaryWriter(str(eventdir))

    #######################
    # Load PyTorch Models #
    #######################
    ratios = [int(w) for w in args.ratios.split()]
    netG = Generator(args.n_mel_channels,
                     args.ngf,
                     args.n_residual_layers,
                     ratios=ratios).to(_device)
    netD = Discriminator(args.num_D, args.ndf, args.n_layers_D,
                         args.downsamp_factor).to(_device)
    # fft = Audio2Mel(n_mel_channels=args.n_mel_channels).to(_device)
    if args.mode == 'default':
        fft = audio2mel
    elif args.mode == 'synthesizer':
        fft = audio2mel_synthesizer
    elif args.mode == 'mellotron':
        fft = audio2mel_mellotron
    else:
        raise KeyError
    # print(netG)
    # print(netD)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optD = torch.optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))

    if load_root and load_root.exists():
        netG.load_state_dict(torch.load(load_root))
        # optG.load_state_dict(torch.load(load_root / "optG.pt"))
        # netD.load_state_dict(torch.load(load_root / "netD.pt"))
        # optD.load_state_dict(torch.load(load_root / "optD.pt"))

    #######################
    # Create data loaders #
    #######################
    train_set = AudioDataset(Path(args.data_path),
                             args.seq_len,
                             sampling_rate=args.sample_rate)
    test_set = AudioDataset(
        Path(args.data_path),  # test file
        args.sample_rate * 4,
        sampling_rate=args.sample_rate,
        augment=False,
    )

    # Save the training data (list of training audio files)
    with open(metadata_dir.joinpath('train.yml'), 'wt',
              encoding='utf8') as fout:
        yaml.dump([str(w.absolute()) for w in train_set.audio_files],
                  fout,
                  default_flow_style=False,
                  encoding='utf-8',
                  allow_unicode=True)

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=args.dataloader_num_workers,
                              shuffle=True)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=True)

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    for i, x_t in enumerate(test_loader):
        x_t = x_t.to(_device)
        s_t = fft(x_t).detach()

        test_voc.append(s_t.to(_device))
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        oridir = root / "original"
        oridir.mkdir(exist_ok=True)
        save_sample(oridir / ("original_{}_{}.wav".format("test", i)),
                    args.sample_rate, audio)
        writer.add_audio("original/{}/sample_{}.wav".format("test", i),
                         audio,
                         0,
                         sample_rate=args.sample_rate)
        mel_outputs = fft(x_t)
        writer.add_image("original/{}/sample_{}.npy".format("test", i),
                         plot_spectrogram_to_numpy(
                             mel_outputs[0].data.cpu().numpy()),
                         0,
                         dataformats='HWC')
        if i == args.n_test_samples - 1:
            break

    costs = []
    start = time.time()

    # enable cudnn autotuner to speed up training
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000
    step_begin = args.start_step
    look_steps = {
        step_begin + 10, step_begin + 100, step_begin + 1000,
        step_begin + 10000
    }
    steps = step_begin
    for epoch in range(1, args.epochs + 1):
        print("\nEpoch {} beginning. Current step: {}".format(epoch, steps))
        for iterno, x_t in enumerate(
                tqdm(train_loader, desc=f"Epoch-{epoch}", ncols=100)):
            # torch.Size([4, 1, 8192]) torch.Size([4, 80, 32])
            # 8192 = 32 x 256
            x_t = x_t.to(_device)
            s_t = fft(x_t).detach()
            x_pred_t = netG(s_t.to(_device))

            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            #######################
            # Train Discriminator #
            #######################
            D_fake_det = netD(x_pred_t.to(_device).detach())
            D_real = netD(x_t.to(_device))

            loss_D = 0
            for scale in D_fake_det:
                loss_D += F.relu(1 + scale[-1]).mean()

            for scale in D_real:
                loss_D += F.relu(1 - scale[-1]).mean()

            netD.zero_grad()
            loss_D.backward()
            optD.step()

            ###################
            # Train Generator #
            ###################
            D_fake = netD(x_pred_t.to(_device))

            loss_G = 0
            for scale in D_fake:
                loss_G += -scale[-1].mean()

            loss_feat = 0
            feat_weights = 4.0 / (args.n_layers_D + 1)
            D_weights = 1.0 / args.num_D
            wt = D_weights * feat_weights
            for i in range(args.num_D):
                for j in range(len(D_fake[i]) - 1):
                    loss_feat += wt * F.l1_loss(D_fake[i][j],
                                                D_real[i][j].detach())

            netG.zero_grad()
            (loss_G + args.lambda_feat * loss_feat).backward()
            optG.step()

            ######################
            # Update tensorboard #
            ######################

            costs.append(
                [loss_D.item(),
                 loss_G.item(),
                 loss_feat.item(), s_error])
            steps += 1
            writer.add_scalar("loss/discriminator", costs[-1][0], steps)
            writer.add_scalar("loss/generator", costs[-1][1], steps)
            writer.add_scalar("loss/feature_matching", costs[-1][2], steps)
            writer.add_scalar("loss/mel_reconstruction", costs[-1][3], steps)

            if steps % args.save_interval == 0 or steps in look_steps:
                st = time.time()
                with torch.no_grad():
                    for i, (voc,
                            true_audio) in enumerate(zip(test_voc,
                                                         test_audio)):
                        pred_audio_ = netG(voc)
                        pred_audio = pred_audio_.squeeze().cpu()
                        gendir = root / "generated"
                        gendir.mkdir(exist_ok=True)
                        save_sample(
                            gendir /
                            ("generated_step{}_{}.wav".format(steps, i)),
                            args.sample_rate, pred_audio)
                        writer.add_audio(
                            "generated/step{}/sample_{}.wav".format(steps, i),
                            pred_audio,
                            epoch,
                            sample_rate=args.sample_rate,
                        )
                        # Inspect the spectrogram to get an intuitive view of the generated audio
                        mel_outputs = fft(pred_audio_.detach())
                        writer.add_image(
                            "generated/step{}/sample_{}.npy".format(steps, i),
                            plot_spectrogram_to_numpy(
                                mel_outputs[0].data.cpu().numpy()),
                            epoch,
                            dataformats='HWC')

                ptdir = root / "models"
                ptdir.mkdir(exist_ok=True)
                torch.save(netG.state_dict(),
                           ptdir / "step{}_netG.pt".format(steps))
                torch.save(optG.state_dict(),
                           ptdir / "step{}_optG.pt".format(steps))

                torch.save(netD.state_dict(),
                           ptdir / "step{}_netD.pt".format(steps))
                torch.save(optD.state_dict(),
                           ptdir / "step{}_optD.pt".format(steps))

                if (np.asarray(costs).mean(0)[-1] < best_mel_reconst) or (
                        steps % (args.save_interval * 10) == 0):
                    best_mel_reconst = np.asarray(costs).mean(0)[-1]
                    torch.save(netD,
                               ptdir / "best_step{}_netD.pt".format(steps))
                    torch.save(netG,
                               ptdir / "best_step{}_netG.pt".format(steps))

                # print("\nTook %5.4fs to generate samples" % (time.time() - st))
                # print("-" * 100)

            if steps % args.log_interval == 0 or steps in look_steps:
                print(
                    "\nEpoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".
                    format(
                        epoch,
                        iterno,
                        len(train_loader),
                        1000 * (time.time() - start) / args.log_interval,
                        np.asarray(costs).mean(0),
                    ))
                costs = []
                start = time.time()
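
The discriminator and generator objectives in the loop above are the hinge GAN losses: the discriminator pushes real scores above +1 and fake scores below -1, and the generator simply raises the discriminator's score on its outputs. A compact sketch of those two loss terms follows; the dummy score tensors and their shapes are illustrative only.

import torch
import torch.nn.functional as F

# Hinge GAN losses with dummy discriminator scores.
d_real_score = torch.randn(4, 1)   # D(x) on real audio
d_fake_score = torch.randn(4, 1)   # D(G(s)) on generated audio

# Discriminator loss: hinge on both real and fake scores.
loss_D = F.relu(1 - d_real_score).mean() + F.relu(1 + d_fake_score).mean()

# Generator loss: raise the discriminator's score on generated samples.
loss_G = -d_fake_score.mean()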
Example #13
0
def train_and_eval_model(args, model, fv, ks, train_dataloader,
                         valid_dataloader, test_dataloader, device):
    ## Define loss criterion, optimizer and adaptive learning-rate scheduler
    criterion = nn.MSELoss(reduction='mean')
    optimizer = optim.RMSprop(model.parameters(),
                              lr=args.learning_rate,
                              alpha=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=1,
                                                     threshold=0.01,
                                                     verbose=True)
    writer = SummaryWriter(
        "runs/" +
        f"BS={args.batchsize}_maxEp={args.maxepoch}_LR={args.learning_rate}_ks={ks}_model={fv}"
    )

    for epoch in range(1, args.maxepoch + 1):
        model.train()
        running_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            x_batch = batch[0].to(device)
            y_batch = batch[1].to(device)

            optimizer.zero_grad()  # zero the parameter gradients
            loss = criterion(model(x_batch), y_batch)
            loss.backward()  # backpropagate loss
            optimizer.step()  # update parameters
            running_loss += loss.item() / args.nbands

        print("Epoch = " + str(epoch) + "  :Total loss = " + str(running_loss))
        scheduler.step(running_loss)  # this adjusts the adaptive LR scheduler
        writer.add_scalar('Loss', running_loss, epoch)  # log training stats

    with torch.no_grad():
        model.eval()
        for step, batch in enumerate(valid_dataloader):
            x_batch = batch[0].to(device)
            y_batch = batch[1].to(device)
            pred_batch = model(x_batch)
            loss = criterion(pred_batch, y_batch)
            tloss = loss.item()

            floss = calfloss(pred_batch, y_batch)
            ploss = floss.item()

        writer.add_scalar('valid loss', tloss)
        writer.add_scalar('valid fractional loss', ploss)
        print('Total valid loss is ' + str(ploss))

        for step, batch in enumerate(test_dataloader):
            x_batch = batch[0].to(device)
            y_batch = batch[1].to(device)
            pred_batch = model(x_batch)
            loss = criterion(pred_batch, y_batch)
            tloss = loss.item()

            floss = calfloss(pred_batch, y_batch)
            ploss = floss.item()

        writer.add_scalar('test loss', tloss)
        writer.add_scalar('test fractional loss', ploss)
        print('Total test loss is ' + str(ploss))

    writer.close()
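
The scheduler in this example is ReduceLROnPlateau driven by the running training loss, so the learning rate drops once that loss stops improving. Below is a minimal sketch of just the scheduler behaviour; the toy model and the hard-coded loss sequence are placeholders, not values from the example.

import torch
import torch.nn as nn

# ReduceLROnPlateau sketch: the LR is reduced after `patience` epochs without
# sufficient improvement of the monitored metric.
model = nn.Linear(10, 1)
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3, alpha=0.9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       patience=1,
                                                       threshold=0.01)

for epoch, fake_loss in enumerate([1.0, 0.9, 0.9, 0.9, 0.9]):
    scheduler.step(fake_loss)            # pass the metric being monitored
    print(epoch, optimizer.param_groups[0]["lr"])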
Example #14
0
File: Train.py Project: neuromorphs/l2race
def train_network():
    print('')
    print('')
    # Start measuring time - to evaluate performance of the training function
    start = timeit.default_timer()

    # Set seeds
    set_seed(args)

    # Make folders if not yet exist
    try:
        os.makedirs('save')
    except FileExistsError:
        pass

    # Save relevant arguments from args and set hardcoded arguments
    lr = args.lr  # learning rate
    batch_size = args.batch_size  # Mini-batch size
    num_epochs = args.num_epochs  # Number of epochs to train the network
    seq_len = args.seq_len

    # Network architecture:
    rnn_name = args.rnn_name
    inputs_list = args.inputs_list
    outputs_list = args.outputs_list

    load_rnn = args.load_rnn  # If specified this is the name of pretrained RNN which should be loaded
    path_save = args.path_save

    # Create RNN instance and update lists of inputs, outputs and its name (if a pretrained net is loaded)
    net, rnn_name, inputs_list, outputs_list \
        = create_rnn_instance(rnn_name, inputs_list, outputs_list, load_rnn, path_save, device)

    # Create log for this RNN and determine its full name
    rnn_full_name = create_log_file(rnn_name, inputs_list, outputs_list,
                                    path_save)

    ########################################################
    # Create Dataset
    ########################################################

    train_features, train_targets = load_data(args, args.train_file_name,
                                              inputs_list, outputs_list)
    dev_features, dev_targets = load_data(args, args.val_file_name,
                                          inputs_list, outputs_list)

    train_set = Dataset(train_features, train_targets, args)
    dev_set = Dataset(dev_features, dev_targets, args)
    print('Number of samples in training set: {}'.format(
        train_set.number_of_samples))
    print('The training set sizes are: {}'.format(train_set.df_lengths))
    print('Number of samples in validation set: {}'.format(
        dev_set.number_of_samples))
    print('')

    plot_results(
        net=net,
        args=args,
        dataset=dev_set,
        filepath='../../data/oval_easy_12_rounds.csv',
        seq_len=400,
        comment='This is the network at the beginning of the training',
        inputs_list=inputs_list,
        outputs_list=outputs_list,
        rnn_full_name=rnn_full_name)

    # Create PyTorch dataloaders for train and dev set
    train_generator = data.DataLoader(dataset=train_set,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers)
    dev_generator = data.DataLoader(dataset=dev_set,
                                    batch_size=512,
                                    shuffle=False,
                                    num_workers=args.num_workers)

    # Print parameter count
    print_parameter_count(net)  # Seems not to function well

    # Select Optimizer
    optimizer = optim.Adam(net.parameters(), amsgrad=True, lr=lr)

    # TODO: Verify if scheduler is working. Try tweaking parameters of below scheduler and try cyclic lr scheduler

    # scheduler = lr_scheduler.CyclicLR(optimizer, base_lr=lr, max_lr=0.1)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)

    # Select Loss Function
    criterion = nn.MSELoss()  # Mean square error loss function
    '''
    Init Tensorboard
    '''
    comment = f' batch_size={batch_size} lr={lr} seq_len={seq_len}'
    tb = SummaryWriter(comment=comment)
    ########################################################
    # Training
    ########################################################
    print("Starting training...")
    print('')
    time.sleep(0.001)

    # Create dictionary to store training history
    dict_history = {}
    dict_history['epoch'] = []
    dict_history['time'] = []
    dict_history['lr'] = []
    dict_history['train_loss'] = []
    dict_history['dev_loss'] = []
    dict_history['dev_gain'] = []
    dict_history['test_loss'] = []
    dev_gain = 1

    # The epoch_saved variable will indicate from which epoch is the last RNN model,
    # which was good enough to be saved
    epoch_saved = -1
    for epoch in range(num_epochs):

        ###########################################################################################################
        # Training - Iterate batches
        ###########################################################################################################
        # Set RNN in training mode
        net = net.train()
        # Define variables accumulating training loss and counting training batches
        train_loss = 0
        train_batches = 0

        # Iterate training over available batches
        # tqdm() is just a function which displays the progress bar
        # Otherwise the line below is the same as "for batch, labels in train_generator:"
        for batch, labels in tqdm(train_generator):  # Iterate through batches

            # Reset the network (internal states of hidden layers and output history not the weights!)
            net.reset()

            # Further modifying the input and output form to fit RNN requirements
            # If GPU available we send tensors to GPU (cuda)
            if torch.cuda.is_available():
                batch = batch.float().cuda().transpose(0, 1)
                labels = labels.float().cuda()
            else:
                batch = batch.float().transpose(0, 1)
                labels = labels.float()

            # # Reset memory of gradients
            # optimizer.zero_grad()

            # Warm-up (open loop prediction) to settle the internal state of RNN hidden layers
            net(rnn_input=batch[:args.warm_up_len, :, :])

            # Reset memory of gradients
            optimizer.zero_grad()

            # Forward propagation - These are the results from which we calculate the update to RNN weights
            # GRU Input size must be (seq_len, batch, input_size)
            net(rnn_input=batch[args.warm_up_len:, :, :])
            out = net.return_outputs_history()

            # Get loss
            loss = criterion(out[:, args.warm_up_len:, :],
                             labels[:, args.warm_up_len:, :])

            # Backward propagation
            loss.backward()

            # Gradient clipping - prevent gradient from exploding
            torch.nn.utils.clip_grad_norm_(net.parameters(), 100)

            # Update parameters
            optimizer.step()
            scheduler.step()
            # Update variables for loss calculation
            batch_loss = loss.detach()
            train_loss += batch_loss  # Accumulate loss
            train_batches += 1  # Accumulate count so we can calculate mean later

        ###########################################################################################################
        # Validation - Iterate batches
        ###########################################################################################################

        # Set the network in evaluation mode
        net = net.eval()

        # Define variables accumulating evaluation loss and counting evaluation batches
        dev_loss = 0
        dev_batches = 0

        for (batch, labels) in tqdm(dev_generator):

            # Reset the network (internal states of hidden layers and output history not the weights!)
            net.reset()

            # Further modifying the input and output form to fit RNN requirements
            # If GPU available we send tensors to GPU (cuda)
            if torch.cuda.is_available():
                batch = batch.float().cuda().transpose(0, 1)
                labels = labels.float().cuda()
            else:
                batch = batch.float().transpose(0, 1)
                labels = labels.float()

            # Warm-up (open loop prediction) to settle the internal state of RNN hidden layers
            net(rnn_input=batch)
            out = net.return_outputs_history()

            # Get loss
            # For evaluation we always calculate loss over the whole maximal prediction period
            # This allows us to compare RNN models from different epochs
            loss = criterion(out[:, args.warm_up_len:args.seq_len],
                             labels[:, args.warm_up_len:args.seq_len])

            # Update variables for loss calculation
            batch_loss = loss.detach()
            dev_loss += batch_loss  # Accumulate loss
            dev_batches += 1  # Accumulate count so we can calculate mean later

        # Reset the network (internal states of hidden layers and output history not the weights!)
        net.reset()
        # Get current learning rate
        # TODO (Fixed, it does change now): I thought the learning rate did not change during training, or that this was not the right way to get this info.

        for param_group in optimizer.param_groups:
            lr_curr = param_group['lr']
        '''
        Add data for tensorboard
        TODO : Add network graph and I/O to tensorboard
        '''
        # tb.add_graph(net)
        tb.add_scalar('Train Loss', train_loss / train_batches, epoch)
        tb.add_scalar('Dev Loss', dev_loss / dev_batches, epoch)

        # Add the first sample of the batch to tensorboard. The prediction is represented by a dotted line
        # TODO: Concatenate such graphs. But they are not continuous
        for i in range(labels.shape[2]):
            time_label = np.arange(0, labels.shape[1], 1)
            time_out = np.arange(0, out.shape[1], 1)
            true_data = labels[1, :, i]
            predicted_data = out[1, :, i]
            fig_tb = plt.figure(5)
            plt.plot(time_label, true_data.detach().cpu())
            plt.plot(time_out,
                     predicted_data.detach().cpu(),
                     linestyle='dashed')
            tb.add_figure(tag=str(args.outputs_list[i]),
                          figure=fig_tb,
                          global_step=epoch)

        for name, param in net.named_parameters():
            tb.add_histogram(name, param, epoch)
            tb.add_histogram(f'{name}.grad', param.grad, epoch)
        tb.close()

        # Write the summary information about the training for the just completed epoch to a dictionary

        dict_history['epoch'].append(epoch)
        dict_history['lr'].append(lr_curr)
        dict_history['train_loss'].append(train_loss.detach().cpu().numpy() /
                                          train_batches /
                                          (args.seq_len - args.warm_up_len))
        dict_history['dev_loss'].append(dev_loss.detach().cpu().numpy() /
                                        dev_batches /
                                        (args.seq_len - args.warm_up_len))

        # Get relative loss gain for network evaluation
        if epoch >= 1:
            dev_gain = (dict_history['dev_loss'][epoch - 1] - dict_history['dev_loss'][epoch]) / \
                       dict_history['dev_loss'][epoch - 1]
        dict_history['dev_gain'].append(dev_gain)

        # Print the summary information about the training for the just completed epoch
        print('\nEpoch: %3d of %3d | '
              'LR: %1.5f | '
              'Train-L: %6.4f | '
              'Val-L: %6.4f | '
              'Val-Gain: %3.2f |' %
              (dict_history['epoch'][epoch], num_epochs - 1,
               dict_history['lr'][epoch], dict_history['train_loss'][epoch],
               dict_history['dev_loss'][epoch],
               dict_history['dev_gain'][epoch] * 100))
        print('')

        # Save the best model with the lowest dev loss
        # Always save the model from epoch 0
        # TODO: this is a bug: you should only save the model from epoch 0 if there is no pretrained network
        if epoch == 0:
            min_dev_loss = dev_loss
        # If the current loss is smaller than or equal to the minimum loss achieved so far,
        # save the current RNN model and record its loss as the new minimum
        if dev_loss <= min_dev_loss:
            epoch_saved = epoch
            min_dev_loss = dev_loss
            torch.save(net.state_dict(),
                       args.path_save + rnn_full_name + '.pt',
                       _use_new_zipfile_serialization=False)
            print('>>> saving best model from epoch {}'.format(epoch))
            print('')
        else:
            print('>>> We keep model from epoch {}'.format(epoch_saved))
            print('')

        plot_string = 'This is the network after {} training epochs'.format(
            epoch + 1)
        plot_results(net=net,
                     args=args,
                     dataset=dev_set,
                     filepath='../../data/oval_easy_12_rounds.csv',
                     seq_len=600,
                     comment=plot_string,
                     inputs_list=inputs_list,
                     outputs_list=outputs_list,
                     rnn_full_name=rnn_full_name)
        # Evaluate the performance of the current network
        # by checking its predictions on a randomly generated CartPole experiment
        # plot_results(net, args, val_file)

    # When finished the training print the final message
    print(
        "Training Completed...                                               ")
    print(" ")

    # Calculate the total time it took to run the function
    stop = timeit.default_timer()
    total_time = stop - start

    # Return the total time it took to run the function
    return total_time
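
The inner training loop above follows the usual ordering: zero the gradients, forward pass, backward pass, clip_grad_norm_, optimizer.step(), then scheduler.step(). Below is a self-contained sketch of that ordering with a toy GRU; the max-norm of 100 mirrors the value chosen above, while the model, data, and shapes are placeholders.

import torch
import torch.nn as nn

# backward / clip / step / scheduler ordering on a toy GRU.
model = nn.GRU(input_size=3, hidden_size=8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, amsgrad=True)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)
criterion = nn.MSELoss()

batch = torch.randn(50, 4, 3)            # (seq_len, batch, input_size)
target = torch.randn(50, 4, 8)

optimizer.zero_grad()
out, _ = model(batch)
loss = criterion(out, target)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 100)  # clip before the step
optimizer.step()
scheduler.step()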
Example #15
0
def train(args, train_dataset, model, tokenizer):
    """ Train the model """

    if args.local_rank in [-1, 0]:
        tensorboard_log_dir = os.path.join(
            "tensorboard", args.task_name, args.data_dir, "_".join([
                args.model_name_or_path,
                str(args.max_seq_length),
                str(
                    max(1, args.n_gpu) * args.gradient_accumulation_steps *
                    args.per_gpu_train_batch_size),
                str(args.learning_rate),
                str(args.weight_decay),
                str(args.warmup_steps)
            ]), str(args.seed))
        logger.info("Tensorboard dir: %s", tensorboard_log_dir)
        tb_writer = SummaryWriter(log_dir=tensorboard_log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
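    # Note: parameters whose names contain "bias" or "LayerNorm.weight" go into
    # the weight_decay=0.0 group below; excluding biases and LayerNorm weights
    # from weight decay is the usual convention when fine-tuning transformers.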
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)
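    # Note: get_linear_schedule_with_warmup increases the learning rate linearly
    # from 0 to the optimizer's initial lr over `num_warmup_steps` steps, then
    # decays it linearly to 0 by `num_training_steps`.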

    # Check if saved optimizer or scheduler states exist

    if args.resume:
        opt_path = os.path.join(args.model_name_or_path, "optimizer.pt")
        sch_path = os.path.join(args.model_name_or_path, "scheduler.pt")

        if os.path.isfile(opt_path) and os.path.isfile(sch_path):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(torch.load(opt_path))
            scheduler.load_state_dict(torch.load(sch_path))
        else:
            raise RuntimeError(
                f"--resume was set but there are no optimizer and scheduler states at {opt_path} and {sch_path}"
            )

    else:
        logger.info(
            "Not checking for optimizer and scheduler state as --resume was not set. Starting afresh"
        )

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint

    if args.resume:
        if not args.global_step:
            raise ValueError(
                "--global_step (int) has to be set when using --resume")
        global_step = args.global_step
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)
        #
        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])

        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training

            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1

                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }

            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2]
                    if args.model_type in ["bert", "xlnet", "albert"] else None
                )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if (args.local_rank in [-1, 0] and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    logs = {}

                    if args.local_rank == -1 and args.evaluate_during_training:
                        # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)

                        for key, value in results.items():
                            logs[key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["lr"] = learning_rate_scalar
                    logs["train_loss"] = loss_scalar
                    logging_loss = tr_loss

                    logger.info(
                        "Performance at global step: %s",
                        str(global_step),
                    )

                    for key, value in logs.items():
                        logger.info("  %s = %s", key, str(value))
                        tb_writer.add_scalar(key, value, global_step)

                    if args.wandb:
                        wandb_log({**logs, **{"step": global_step}})

                if (args.local_rank in [-1, 0] and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))

                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()

                break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()

            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
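# --- Hedged sketch: the gradient-accumulation update used in the loop above ---
# A minimal, standalone version of the accumulate / clip / step / zero pattern,
# assuming generic `model`, `loader`, `optimizer`, `scheduler`, and `criterion`
# objects (these names are illustrative, not taken from the example above).
import torch


def accumulation_update(model, loader, optimizer, scheduler, criterion,
                        accumulation_steps=4, max_grad_norm=1.0):
    model.train()
    model.zero_grad()
    for step, (inputs, targets) in enumerate(loader):
        loss = criterion(model(inputs), targets)
        # Scale so the accumulated gradient matches one large batch.
        (loss / accumulation_steps).backward()
        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()  # stepped once per optimizer step, as in the loop above
            model.zero_grad()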
        
        for i in range(1, motor_input.shape[0]):
            visual_output, motor_output, hidden, mu, logsig = rnn(
                visual_input[i].view(-1, 1, 64, 64) * (1 - pred_fb) + visual_output * pred_fb,
                motor_input[i].view(-1, 16) * (1 - pred_fb) + motor_output.view(-1, 16) * pred_fb,
                hidden)

            motor_loss = criterion(motor_output.view(-1, 16), motor_target[i].view(-1, 16))
            visual_loss = criterion(visual_output, visual_target[i].view(-1, 1, 64, 64))
            kl_loss = rnn.KL(mu, logsig)

            total_loss_ = visual_loss.item() + motor_loss.item() + kl_loss.item()
            visual_mult = total_loss_ / 3 / visual_loss.item()
            motor_mult = total_loss_ / 3 / motor_loss.item()
            Dkl_mult = total_loss_ / 3 / kl_loss.item()
            loss += visual_loss * visual_mult + motor_loss * motor_mult + kl_loss * Dkl_mult

        loss.backward()
        writer.add_scalar('motor loss', motor_loss, epoch)
        writer.add_scalar('visual loss', visual_loss, epoch)
            
        VLOSS.append(visual_loss.item())
        MLOSS.append(motor_loss.item())
        rnn.optimizer.step()
        printProgressBar(epoch + 1, EPOCHS, prefix='Epoch: {} vloss: {:.2f} mloss: {:.6f} Dkl: {:.3f}'.format(epoch, visual_loss.item(), motor_loss.item(), kl_loss.item()), suffix='Complete', length=25)
        if epoch % 1000 == 0:
            torch.save(rnn.state_dict(), 'checkpoint')

    except KeyboardInterrupt:
        print('\nKeyboard Interrupt')
        break
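# --- Hedged sketch: the equal-contribution loss reweighting used above ---
# Each term is rescaled from its detached value so that the visual, motor and KL
# losses each contribute roughly one third of the combined loss. The helper name
# is illustrative; the weights mirror the .item()-based multipliers in the loop.
def equalize_loss_terms(*losses):
    total = sum(l.item() for l in losses)
    n = len(losses)
    return sum(l * (total / n / l.item()) for l in losses)

# e.g. loss = loss + equalize_loss_terms(visual_loss, motor_loss, kl_loss)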
        
예제 #17
0
    def train(self) -> bool:
        """Run training in a separate thread (added to the global application ThreadPool)."""

        # Free memory on the GPU
        self._clear_session()

        # Check that the data is set properly
        if len(self._train_image_names) == 0 or \
                len(self._train_mask_names) == 0 or \
                len(self._validation_image_names) == 0 or \
                len(self._validation_mask_names) == 0:
            self._message = "No training/validation data found."
            return False

        if len(self._train_image_names) != len(self._train_mask_names):
            self._message = "The number of training images does not match the number of training masks."
            return False

        if len(self._validation_image_names) != len(
                self._validation_mask_names):
            self._message = "The number of validation images does not match the number of validation masks."
            return False

        # Define the transforms
        self._define_transforms()

        # Define the datasets and data loaders
        self._define_training_data_loaders()

        # Instantiate the model
        self._define_model()

        # Define the loss function
        self._define_training_loss()

        # Define the optimizer (with default parameters)
        self._define_optimizer()

        # Define the validation metric
        self._define_validation_metric()

        # Define experiment name and model name
        experiment_name, model_file_name = self._prepare_experiment_and_model_names()

        # Keep track of the best model file name
        self._best_model = model_file_name

        # Enter the main training loop
        best_metric = -1
        best_metric_epoch = -1

        epoch_loss_values = list()
        metric_values = list()

        # Initialize TensorBoard's SummaryWriter
        writer = SummaryWriter(experiment_name)

        for epoch in range(self._n_epochs):

            # Inform
            self._print_header(f"Epoch {epoch + 1}/{self._n_epochs}")

            # Switch to training mode
            self._model.train()

            epoch_loss = 0
            step = 0
            for batch_data in self._train_dataloader:

                # Update step
                step += 1

                # Get the next batch and move it to device
                inputs, labels = batch_data[0].to(
                    self._device), batch_data[1].to(self._device)

                # Zero the gradient buffers
                self._optimizer.zero_grad()

                # Forward pass
                outputs = self._model(inputs)

                # Calculate the loss
                loss = self._training_loss_function(outputs, labels)

                # Back-propagate
                loss.backward()

                # Update weights (optimize)
                self._optimizer.step()

                # Update and store metrics
                epoch_loss += loss.item()
                epoch_len = len(
                    self._train_dataset) / self._train_dataloader.batch_size
                if epoch_len != int(epoch_len):
                    epoch_len = int(epoch_len) + 1

                print(
                    f"Batch {step}/{epoch_len}: train_loss = {loss.item():.4f}",
                    file=self._stdout)

            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"Average loss = {epoch_loss:.4f}", file=self._stdout)
            writer.add_scalar("average_train_loss", epoch_loss, epoch + 1)

            # Validation
            if (epoch + 1) % self._validation_step == 0:

                self._print_header("Validation")

                # Switch to evaluation mode
                self._model.eval()

                # Make sure not to update the gradients
                with torch.no_grad():

                    # Global metrics
                    metric_sum = 0.0
                    metric_count = 0
                    metric = 0.0

                    # Keep track of the metrics for all classes
                    metric_sum_classes = self._out_channels * [0.0]
                    metric_count_classes = self._out_channels * [0]
                    metric_classes = self._out_channels * [0.0]

                    for val_data in self._validation_dataloader:

                        # Get the next batch and move it to device
                        val_images, val_labels = val_data[0].to(
                            self._device), val_data[1].to(self._device)

                        # Apply sliding inference over ROI size
                        val_outputs = sliding_window_inference(
                            val_images, self._roi_size,
                            self._sliding_window_batch_size, self._model)
                        val_outputs = self._validation_post_transforms(
                            val_outputs)

                        # Compute overall metric
                        value, not_nans = self._validation_metric(
                            y_pred=val_outputs, y=val_labels)
                        not_nans = not_nans.item()
                        metric_count += not_nans
                        metric_sum += value.item() * not_nans

                        # Compute metric for each class
                        for c in range(self._out_channels):
                            value_obj, not_nans = self._validation_metric(
                                y_pred=val_outputs[:, c:c + 1],
                                y=val_labels[:, c:c + 1])
                            not_nans = not_nans.item()
                            metric_count_classes[c] += not_nans
                            metric_sum_classes[c] += value_obj.item() * not_nans

                    # Global metric
                    metric = metric_sum / metric_count
                    metric_values.append(metric)

                    # Metric per class
                    for c in range(self._out_channels):
                        metric_classes[c] = metric_sum_classes[
                            c] / metric_count_classes[c]

                    # Print summary
                    print(f"Global metric = {metric:.4f} ", file=self._stdout)
                    for c in range(self._out_channels):
                        print(
                            f"Class '{self._class_names[c]}' metric = {metric_classes[c]:.4f} ",
                            file=self._stdout)

                    # Do we have the best metric so far?
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(self._model.state_dict(), model_file_name)
                        print(
                            f"New best global metric = {best_metric:.4f} at epoch: {best_metric_epoch}",
                            file=self._stdout)
                        print(
                            f"Saved best model '{Path(model_file_name).name}'",
                            file=self._stdout)

                    # Add validation loss and metrics to log
                    writer.add_scalar("val_mean_dice_loss", metric, epoch + 1)
                    for c in range(self._out_channels):
                        metric_name = f"val_{self._class_names[c].lower()}_metric"
                        writer.add_scalar(metric_name, metric_classes[c],
                                          epoch + 1)

        print(
            f"Training completed. Best_metric = {best_metric:.4f} at epoch: {best_metric_epoch}",
            file=self._stdout)
        writer.close()

        # Return success
        return True
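# --- Hedged sketch: the not_nans-weighted metric aggregation used in the validation above ---
# `metric_fn` is assumed to return (value, not_nans) per batch, like the MONAI
# metric in the example; the function and argument names here are illustrative.
import torch


def aggregate_metric(metric_fn, dataloader, model, device):
    metric_sum, metric_count = 0.0, 0
    model.eval()
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images.to(device))
            value, not_nans = metric_fn(y_pred=outputs, y=labels.to(device))
            not_nans = not_nans.item()
            metric_count += not_nans
            metric_sum += value.item() * not_nans
    # Weighted mean over all batches that contributed valid values
    return metric_sum / metric_count if metric_count > 0 else float('nan')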
예제 #18
0
                if param.requires_grad:
                    writer.add_histogram('Model/{}'.format(name), param, epoch)
        # Average the train / validation metrics
        train_loss = torch.mean(torch.tensor(epoch_train_loss))
        train_acc = torch.mean(torch.tensor(epoch_train_acc))
        val_loss = torch.mean(torch.tensor(epoch_val_loss))
        val_acc = torch.mean(torch.tensor(epoch_val_acc))

        if ((epoch + 1) % 100 == 0):
            print(f'\nEpoch {epoch} - Saving model...')
            model_path = os.path.join(logdir, 'model', 'state_dict.pth')
            opt_path = os.path.join(logdir, 'model', 'opt_state.pth')
            torch.save(net.state_dict(), model_path)
            torch.save(optimizer.state_dict(), opt_path)
            print('Done.')
    # ---------------------------------------------------------------------------
    # Log metrics
    writer.add_scalar('Train/Accuracy', train_acc, epoch)
    writer.add_scalar('Train/Loss', train_loss, epoch)
    writer.add_scalar('Validation/Accuracy', val_acc, epoch)
    writer.add_scalar('Validation/Loss', val_loss, epoch)
    writer.add_scalar('Learning Rate', scheduler.get_last_lr()[0], epoch)
    # ---------------------------------------------------------------------------
    print('\n[{:03d}/{:03d}] LOSS--------------   ACCURACY'.\
        format(epoch, config['NUM_EPOCHS']))
    print('[TRAIN]   {}   {} %'.format(train_loss, train_acc))
    print('[VAL]     {}   {} %'.format(val_loss, val_acc))

    # Adjust learning rate
    scheduler.step()
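# --- Hedged sketch: the per-epoch parameter-histogram logging from the fragment above ---
# A self-contained version; `writer` is a torch.utils.tensorboard SummaryWriter
# and `net` is any nn.Module (both are placeholders here).
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter


def log_parameter_histograms(writer: SummaryWriter, net: nn.Module, epoch: int):
    for name, param in net.named_parameters():
        if param.requires_grad:
            writer.add_histogram('Model/{}'.format(name), param, epoch)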
def train_classifier(model, data_loaders, args):
    """Train an emotion classifier."""
    # Setup
    device = args.device
    optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    model, optimizer, _, start_epoch, is_trained = load_from_ckpnt(
        args.classifier_ckpnt, model, optimizer
    )
    scheduler = MultiStepLR(optimizer, [3, 6, 9], gamma=0.3,
                            last_epoch=start_epoch - 1)
    if is_trained:
        return model
    writer = SummaryWriter('runs/' + args.checkpoint.replace('.pt', ''))
    best_acc = -1

    # Training loop
    for epoch in range(start_epoch, args.epochs):
        print("Epoch: %d/%d" % (epoch + 1, args.epochs))
        kbar = pkbar.Kbar(target=len(data_loaders['train']), width=25)
        model.train()
        #model.enable_grads()
        for step, ex in enumerate(data_loaders['train']):
            images, _, emotions, _ = ex
            logits = model(images.to(device))
            labels = emotions.to(device)
            loss = F.binary_cross_entropy_with_logits(logits, labels)
            kbar.update(step, [("loss", loss)])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            writer.add_scalar(
                'loss', loss.item(),
                epoch * len(data_loaders['train']) + step
            )
            break
        writer.add_scalar(
            'lr', optimizer.state_dict()['param_groups'][0]['lr'], epoch
        )
        # Evaluation and model storing
        if epoch % 2 == 0:
            print("\nValidation")
            acc = eval_classifier(model, data_loaders['test'], args, writer, epoch=epoch)
            writer.add_scalar('mAP', acc, epoch)
            if acc >= best_acc:
                torch.save(
                    {
                        "epoch": epoch + 1,
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict()
                    },
                    args.classifier_ckpnt
                )
                best_acc = acc
            else:  # load checkpoint to update epoch
                checkpoint = torch.load(args.classifier_ckpnt)
                checkpoint["epoch"] += 1
                torch.save(checkpoint, args.classifier_ckpnt)
        scheduler.step()
    # Test
    test_acc = eval_classifier(model, data_loaders['test'], args, writer)
    print(f"Test Accuracy: {test_acc}")
    return model
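# --- Hedged sketch: the best-checkpoint save/restore pattern used above ---
# Only standard torch.save / torch.load calls are assumed; the helper names and
# path argument are illustrative, not part of the original example.
import torch


def save_classifier_checkpoint(model, optimizer, epoch, path):
    torch.save(
        {
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        path,
    )


def load_classifier_checkpoint(model, optimizer, path):
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"]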
예제 #20
0
                "Unbiased Accuracy : [f : {:.2f}, g : {:.2f}, v : {:.2f}, b : {:.2f}]"
                .format(
                    100 * f_acc_b /
                    (len(biased_test_loader) * args.batch_size), 100 *
                    g_acc_b / (len(biased_test_loader) * args.batch_size),
                    100 * v_acc_b /
                    (len(biased_test_loader) * args.batch_size), 100 *
                    b_acc_b / (len(biased_test_loader) * args.batch_size),
                    100 * f_acc_d /
                    (len(unbiased_test_loader) * args.batch_size), 100 *
                    g_acc_d / (len(unbiased_test_loader) * args.batch_size),
                    100 * v_acc_d /
                    (len(unbiased_test_loader) * args.batch_size), 100 *
                    b_acc_d / (len(unbiased_test_loader) * args.batch_size)))

            writer.add_scalar("loss/f", loss_f, epoch)
            writer.add_scalar("loss/g", loss_g, epoch)
            writer.add_scalar("loss/v", loss_v, epoch)
            writer.add_scalar("loss/b", loss_b, epoch)
            writer.add_scalar("loss/hsic", criterionHSIC(f_feats, g_feats),
                              epoch)
            writer.add_scalar(
                "accuracy/biased/f",
                100 * f_acc_b / (len(biased_test_loader) * args.batch_size),
                epoch)
            writer.add_scalar(
                "accuracy/biased/g",
                100 * g_acc_b / (len(biased_test_loader) * args.batch_size),
                epoch)
            writer.add_scalar(
                "accuracy/biased/v",
예제 #21
0
class Trainer:
    def __init__(self, config):
        self.config = config
        self.config['trainer']['output_dir'] = os.path.join(
            str(pathlib2.Path(os.path.abspath(__name__)).parent),
            self.config['trainer']['output_dir'])
        self.data_cfg = self.config["data_cfg"]
        self.dataset_name = self.data_cfg['name']
        self.method_name = "{0}_{1}".format(self.config['arch']['backbone'],
                                            self.dataset_name)
        self.save_dir = os.path.join(self.config['trainer']['output_dir'],
                                     self.method_name)
        self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')
        if self.config['trainer']['resume_checkpoint'] == '' and self.config[
                'trainer']['finetune_checkpoint'] == '':
            shutil.rmtree(self.save_dir, ignore_errors=True)
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        self.global_step = 0
        self.start_epoch = 1

        self.tensorboard_enable = self.config['trainer']['tensorboard']
        self.epochs = self.config['trainer']['epochs']
        self.save_interval = self.config['trainer']['save_interval']
        self.show_images_interval = self.config['trainer'][
            'show_images_interval']
        self.display_interval = self.config['trainer']['display_interval']
        if self.tensorboard_enable:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(self.save_dir)

        # setup logger
        self.logger = setup_logger(os.path.join(self.save_dir, 'train_log'))
        self.logger.info(pformat(self.config))

        # device
        torch.manual_seed(self.config['trainer']['seed'])  # set the random seed for the CPU
        if len(self.config['trainer']['gpus']) > 0 and torch.cuda.is_available():
            self.with_cuda = True
            torch.backends.cudnn.benchmark = True
            self.logger.info('Train with gpu {} & PyTorch {}'.format(
                self.config['trainer']['gpus'], torch.__version__))
            self.gpus = {
                i: item
                for i, item in enumerate(self.config['trainer']['gpus'])
            }
            self.device = torch.device("cuda:0")
            torch.cuda.manual_seed(
                self.config['trainer']['seed'])  # set the random seed for the current GPU
            torch.cuda.manual_seed_all(
                self.config['trainer']['seed'])  # set the random seed for all GPUs
        else:
            self.with_cuda = False
            self.logger.info('Train with cpu & PyTorch {}'.format(
                torch.__version__))
            self.device = torch.device("cpu")
        self.logger.info('Device {}'.format(self.device))

        # train data loader
        self.logger.info('Loading train data...')
        self.train_data_len = len(os.listdir(self.data_cfg["train_img_path"]))
        self.train_set = CustomDataSetRBox(
            self.data_cfg,
            max_img_length=self.config["trainer"]["input_size"],
            long_size=self.config['trainer']['long_size'])
        self.embedding_size = self.train_set.embedding_size
        self.words_embeddings = self.train_set.words_embeddings
        self.train_loader = data.DataLoader(
            self.train_set,
            batch_size=self.config["trainer"]["batch_size"],
            shuffle=True,
            num_workers=self.config["trainer"]["num_workers"],
            drop_last=False,
            pin_memory=True)
        self.train_loader_len = len(self.train_loader)
        self.logger.info('Train data has {0} samples, {1} in loader'.format(
            self.train_data_len, self.train_loader_len))

        # test data loader
        self.test_gt_path = self.train_set.test_gt_path
        self.test_img_files = self.train_set.test_img_files
        self.test_gt_files = self.train_set.test_gt_files
        self.test_words = self.train_set.test_words
        self.train_unique_words = self.train_set.train_unique_words
        self.label_encoder = LabelEncoder()

        # model
        self.logger.info('Loading model...')
        self.model = WordRetrievalModel(
            n_out=self.embedding_size,
            backbone=self.config["arch"]["backbone"],
            pre_trained=self.config["arch"]["pre_trained"])

        # loss function
        self.logger.info('Loading loss function...')
        self.criterion = ModelLoss(
            weight_cls=self.config["loss"]["weight_cls"],
            weight_angle=self.config["loss"]["weight_angle"],
            weight_diou=self.config["loss"]["weight_diou"],
            weight_embed=self.config["loss"]["weight_embed"])

        # optimizer and lr_scheduler
        self.logger.info('Loading optimizer and lr_scheduler...')
        self.lr = self.config["optimizer"]['args']['lr']
        self.lr_step = self.config["trainer"]["lr_step"]
        self.optimizer = self._initialize('optimizer', torch.optim,
                                          self.model.parameters())
        self.scheduler = self._initialize('lr_scheduler',
                                          torch.optim.lr_scheduler,
                                          self.optimizer)
        if self.config['trainer']['resume_checkpoint'] != '':
            self._load_checkpoint(self.config['trainer']['resume_checkpoint'],
                                  resume=True)
        elif self.config['trainer']['finetune_checkpoint'] != '':
            self._load_checkpoint(
                self.config['trainer']['finetune_checkpoint'], resume=False)

        # eval args
        self.cls_score_thresh = self.config['tester']['cls_score_thresh']
        self.bbox_nms_overlap = self.config['tester']['bbox_nms_overlap']
        self.query_nms_overlap = self.config['tester']['query_nms_overlap']
        self.overlap_thresh = 0.25
        self.metric = self.config['tester']['distance_metric']

        # single machine, multi-GPU
        num_gpus = torch.cuda.device_count()
        if num_gpus > 1:
            self.model = nn.DataParallel(self.model)
        self.model.to(self.device)

        self.metrics = {
            'precision': 0,
            'recall': 0,
            'hmean': 0,
            'map': 0,
            'mr': 0,
            'train_loss': float('inf'),
            'best_model': ''
        }

    def train(self):
        """ Full training logic """
        self.logger.info('Start training...')
        for epoch in range(self.start_epoch, self.epochs + 1):
            try:
                self.adjust_learning_rate(epoch)
                self.epoch_result = self._train_epoch(epoch)
                self._on_epoch_finish(epoch)
            except torch.cuda.CudaError:
                self._log_memory_usage()
        if self.tensorboard_enable:
            self.writer.close()
        self._on_train_finish()

    def _train_epoch(self, epoch):
        """ Training logic for an epoch """
        self.model.train()
        epoch_start, batch_start = time.time(), time.time()
        train_loss = 0.0
        lr = self.optimizer.param_groups[0]['lr']
        for i, (img, gt_score, gt_geo, ignored_map,
                gt_embedding) in enumerate(self.train_loader):
            if i >= self.train_loader_len:
                break
            self.global_step += 1
            lr = self.optimizer.param_groups[0]['lr']

            cur_batch_size = img.size()[0]
            img, gt_score, gt_geo, ignored_map, gt_embedding = img.to(self.device), gt_score.to(self.device), \
                                                               gt_geo.to(self.device), ignored_map.to(self.device), \
                                                               gt_embedding.to(self.device)

            (predict_score, predict_geo), predict_embedding = self.model(img)
            loss_all, loss_cls, loss_ang, loss_diou, loss_embed = self.criterion(
                gt_score, predict_score, gt_geo, predict_geo, gt_embedding,
                predict_embedding, ignored_map)

            # backward
            self.optimizer.zero_grad()
            loss_all.backward()
            self.optimizer.step()

            loss_all = loss_all.item()
            loss_cls, loss_ang, loss_diou = (loss_cls.item(), loss_ang.item(),
                                             loss_diou.item())
            loss_embed = loss_embed.item()
            train_loss += loss_all

            if i % self.display_interval == 0 or i == self.train_loader_len - 1:
                batch_time = time.time() - batch_start
                self.logger.info(
                    '[{}/{}], [{}/{}], g_step: {}, Spe: {:.1f} sam/sec, l_all: {:.4f}, l_cls: {:.4f}, '
                    'l_ang: {:.4f}, l_diou: {:.4f}, l_embed: {:.4f}, lr: {:.6}, T: {:.2f}'
                    .format(
                        str(epoch).zfill(3), self.epochs,
                        str(i + 1).zfill(3), self.train_loader_len,
                        self.global_step,
                        self.display_interval * cur_batch_size / batch_time,
                        loss_all, loss_cls, loss_ang, loss_diou, loss_embed,
                        lr, batch_time))
                batch_start = time.time()

            if self.tensorboard_enable:
                self.writer.add_scalar('TRAIN/LOSS/loss_all', loss_all,
                                       self.global_step)
                self.writer.add_scalar('TRAIN/LOSS/loss_cls', loss_cls,
                                       self.global_step)
                self.writer.add_scalar('TRAIN/LOSS/loss_ang', loss_ang,
                                       self.global_step)
                self.writer.add_scalar('TRAIN/LOSS/loss_diou', loss_diou,
                                       self.global_step)
                self.writer.add_scalar('TRAIN/LOSS/loss_embed', loss_embed,
                                       self.global_step)
                self.writer.add_scalar('TRAIN/lr', lr, self.global_step)

        return {
            'train_loss': train_loss / self.train_loader_len,
            'lr': lr,
            'time': time.time() - epoch_start,
            'epoch': epoch
        }

    def _eval_map(self):
        self.logger.info('Enter evaluating...')
        self.model.eval()
        result_save_path = os.path.join(self.save_dir, 'result')
        if os.path.exists(result_save_path):
            shutil.rmtree(result_save_path, ignore_errors=True)
        if not os.path.exists(result_save_path):
            os.makedirs(result_save_path)

        predict_embeddings, joint_boxes, all_gt_boxes = [], [], []
        qbs_words, qbs_queries, qbs_targets, db_targets, gt_targets = [], [], [], [], []
        overlaps, used_test_word = [], []

        # Compute a mapping from class string to class id...
        self.label_encoder.fit([word for word in self.test_words])

        # Create queries...
        test_unique_words, counts = np.unique(self.test_words,
                                              return_counts=True)
        for idx, test_word in enumerate(self.test_words):
            gt_targets.extend(self.label_encoder.transform([test_word]))
            if test_word not in used_test_word and test_word in test_unique_words:
                qbs_words.append(test_word)
                qbs_queries.append(self.words_embeddings[test_word])
                qbs_targets.extend(self.label_encoder.transform([test_word]))
                used_test_word.append(test_word)

        for i, (img_file, gt_file) in enumerate(
                zip(self.test_img_files, self.test_gt_files)):
            self.logger.info('Evaluating {} image: {}'.format(i, img_file))
            # Get test gt boxes & gt words...
            gt_boxes, gt_words = [], []
            with open(gt_file, mode='r', encoding='utf-8') as f:
                lines = f.readlines()
            for line in lines:
                line = line.strip().rstrip('\n').lstrip(
                    '\ufeff').strip().split(',', maxsplit=8)
                gt_boxes.append([int(ver) for ver in line[:8]])
                gt_words.append(str(line[-1]).strip().lower())

            # Get img...
            im = Image.open(img_file)
            im = im.convert("RGB")
            im, ratio_w, ratio_h = resize_img(
                im, long_size=self.config['trainer']['long_size'])
            with torch.no_grad():
                if 'cuda' in str(self.device):
                    torch.cuda.synchronize(self.device)
                transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                         std=(0.5, 0.5, 0.5))
                ])
                im = transform(im).unsqueeze(0)
                im = im.to(self.device)
                (predict_score, predict_geo), predict_embed = self.model(im)
                if 'cuda' in str(self.device):
                    torch.cuda.synchronize(self.device)
            # Predicting boxes...
            predict_boxes, _ = get_boxes(
                score=predict_score.squeeze(0).cpu().numpy(),
                geo=predict_geo.squeeze(0).cpu().numpy(),
                cls_score_thresh=self.cls_score_thresh,
                bbox_nms_overlap=self.bbox_nms_overlap)
            predict_embed = predict_embed.squeeze(0).cpu().numpy()

            if predict_boxes is None:
                continue
            self.logger.info(
                'Idx: {0} ===> Predict result [predict_boxes: {1}; gt_boxes: {2}]'
                .format(i, predict_boxes.shape, len(gt_boxes)))
            for predict_box in predict_boxes:
                min_x = min(predict_box[0], predict_box[2], predict_box[4],
                            predict_box[6])
                max_x = max(predict_box[0], predict_box[2], predict_box[4],
                            predict_box[6])
                min_y = min(predict_box[1], predict_box[3], predict_box[5],
                            predict_box[7])
                max_y = max(predict_box[1], predict_box[3], predict_box[5],
                            predict_box[7])
                w, h = max_x - min_x, max_y - min_y
                differ = h * 0.2 if h < w else w * 0.2
                min_x, max_x = int((min_x + differ) / 4), int(
                    (max_x - differ) / 4)
                min_y, max_y = int((min_y + differ) / 4), int(
                    (max_y - differ) / 4)
                if min_x > max_x or min_y > max_y:
                    continue
                predict_embeddings.append(
                    np.mean(predict_embed[:, min_y:max_y, min_x:max_x],
                            axis=(1, 2)))

            predict_boxes = adjust_ratio(predict_boxes, ratio_w, ratio_h)
            seq = []
            if predict_boxes is not None:
                seq.extend([
                    ','.join([str(int(b)) for b in box[:-1]]) + '\n'
                    for box in predict_boxes
                ])
            with open(
                    os.path.join(
                        result_save_path,
                        str(os.path.basename(img_file).split('.')[0]) +
                        '.txt'), 'w') as f:
                f.writelines(seq)

            joint_boxes.extend(predict_boxes[:, :8])
            all_gt_boxes.extend(gt_boxes)
            gt_boxes = np.array(gt_boxes)
            # Calculate overlap...
            overlap = cal_overlap(predict_boxes, gt_boxes)
            overlaps.append(overlap)
            inds = overlap.argmax(axis=1)
            db_targets.extend(
                self.label_encoder.transform([gt_words[idx] for idx in inds]))

        # End evaluate...
        db = np.vstack(predict_embeddings) if len(
            predict_embeddings) != 0 else np.array(predict_embeddings)
        all_overlaps = np.zeros((len(joint_boxes), len(all_gt_boxes)),
                                dtype=np.float32)
        x, y = 0, 0
        for o in overlaps:
            all_overlaps[y:y + o.shape[0], x:x + o.shape[1]] = o
            y += o.shape[0]
            x += o.shape[1]
        db_targets, qbs_targets, qbs_words = np.array(db_targets), np.array(
            qbs_targets), np.array(qbs_words)
        qbs_queries, joint_boxes = np.array(qbs_queries), np.array(joint_boxes)

        assert (qbs_queries.shape[0] == qbs_targets.shape[0])
        assert (db.shape[0] == db_targets.shape[0])

        self.logger.info('Calculate mAP...')
        mAP_qbs, mR_qbs = cal_map(qbs_queries,
                                  qbs_targets,
                                  db,
                                  db_targets,
                                  gt_targets,
                                  joint_boxes,
                                  all_overlaps,
                                  self.query_nms_overlap,
                                  self.overlap_thresh,
                                  qbs_words,
                                  num_workers=0)
        mAP_qbs, mR_qbs = np.mean(mAP_qbs * 100), np.mean(mR_qbs * 100)

        # Calculate recall precision f1
        res_dict = cal_recall_precison_f1(gt_path=self.test_gt_path,
                                          result_path=result_save_path)
        return res_dict['recall'], res_dict['precision'], res_dict[
            'hmean'], mAP_qbs, mR_qbs

    def _on_epoch_finish(self, epoch):
        # torch.cuda.empty_cache()
        self.logger.info(
            '[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.format(
                self.epoch_result['epoch'], self.epochs,
                self.epoch_result['train_loss'], self.epoch_result['time'],
                self.epoch_result['lr']))

        if epoch % self.save_interval == 0:
            net_save_path = '{}/WordRetrievalNet_latest.pth'.format(
                self.checkpoint_dir)

            save_best = False
            if self.config['trainer']['metrics'] == 'map':  # use mAP as the best-model criterion
                recall, precision, hmean, mAP_qbs, mR_qbs = self._eval_map()

                if self.tensorboard_enable:
                    self.writer.add_scalar('EVAL/precision', precision,
                                           self.global_step)
                    self.writer.add_scalar('EVAL/recall', recall,
                                           self.global_step)
                    self.writer.add_scalar('EVAL/hmean', hmean,
                                           self.global_step)
                    self.writer.add_scalar('EVAL/mAP', mAP_qbs,
                                           self.global_step)
                    self.writer.add_scalar('EVAL/mR', mR_qbs, self.global_step)
                self.logger.info(
                    'test: precision: {:.6f}, recall: {:.6f}, f1: {:.6f}, map: {:.2f}, mr: {:.2f}'
                    .format(precision, recall, hmean, mAP_qbs, mR_qbs))

                if mAP_qbs > self.metrics['map']:
                    save_best = True
                    self.metrics['train_loss'], self.metrics[
                        'best_model'] = self.epoch_result[
                            'train_loss'], net_save_path
                    self.metrics['precision'], self.metrics[
                        'recall'], self.metrics[
                            'hmean'] = precision, recall, hmean
                    self.metrics['map'], self.metrics['mr'] = mAP_qbs, mR_qbs
            else:
                if self.epoch_result['train_loss'] < self.metrics['train_loss']:
                    save_best = True
                    self.metrics['train_loss'], self.metrics[
                        'best_model'] = self.epoch_result[
                            'train_loss'], net_save_path
            self._save_checkpoint(self.epoch_result['epoch'], net_save_path,
                                  save_best)

    def _on_train_finish(self):
        for k, v in self.metrics.items():
            self.logger.info('{}:{}'.format(k, v))
        self.logger.info('Finish train.')

    def _log_memory_usage(self):
        if not self.with_cuda:
            return
        usage = []
        for deviceID, device in self.gpus.items():
            allocated = torch.cuda.memory_allocated(
                int(deviceID)) / (1024 * 1024)
            cached = torch.cuda.memory_cached(int(deviceID)) / (1024 * 1024)
            usage.append(
                '    CUDA: {0}; Allocated: {1} MB; Cached: {2} MB \n'.format(
                    device, allocated, cached))
        self.logger.debug("Memory Usage: \n{}".format(''.join(usage)))

    def _save_checkpoint(self, epoch, file_name, save_best=False):
        """ Saving checkpoints """
        state_dict = {
            'epoch': epoch,
            'global_step': self.global_step,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'config': self.config,
            'metrics': self.metrics,
        }
        filename = os.path.join(self.checkpoint_dir, file_name)
        torch.save(state_dict, filename)
        if save_best:
            shutil.copy(
                filename,
                os.path.join(self.checkpoint_dir, 'WordRetrievalNet_best.pth'))
            self.logger.info("Saving current best: {}".format(file_name))
        else:
            self.logger.info("Saving checkpoint: {}".format(filename))

    def _load_checkpoint(self, checkpoint_path, resume):
        """ Resume from saved checkpoints """
        self.logger.info("Loading checkpoint: {} ...".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict({
            k.replace('module.', ''): v
            for k, v in checkpoint['state_dict'].items()
        })
        if resume:
            self.global_step = checkpoint['global_step']
            self.start_epoch = checkpoint['epoch'] + 1
            self.config['lr_scheduler']['args'][
                'last_epoch'] = self.start_epoch
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            if 'metrics' in checkpoint:
                self.metrics = checkpoint['metrics']
            if self.with_cuda:
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to(self.device)
            self.logger.info("Resume from checkpoint {} (epoch {})".format(
                checkpoint_path, self.start_epoch))
        else:
            self.logger.info(
                "FineTune from checkpoint {}".format(checkpoint_path))

    def _initialize(self, name, module, *args, **kwargs):
        module_name = self.config[name]['type']
        module_args = self.config[name]['args']
        assert all([
            k not in module_args for k in kwargs
        ]), 'Overwriting kwargs given in config file is not allowed'
        module_args.update(kwargs)
        return getattr(module, module_name)(*args, **module_args)

    def adjust_learning_rate(self, epoch):
        if epoch in self.lr_step:
            self.lr = self.lr * 0.1
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr
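# --- Hedged sketch: adjust_learning_rate above behaves roughly like MultiStepLR ---
# The built-in scheduler multiplies the LR by `gamma` at the given milestone epochs;
# the model, learning rate, and milestones below are illustrative placeholders.
import torch
from torch.optim.lr_scheduler import MultiStepLR

_model = torch.nn.Linear(8, 2)
_optimizer = torch.optim.SGD(_model.parameters(), lr=0.01)
_scheduler = MultiStepLR(_optimizer, milestones=[200, 400], gamma=0.1)

for _epoch in range(600):
    # ... one training epoch would run here ...
    _optimizer.step()   # placeholder update so the scheduler step is meaningful
    _scheduler.step()   # drops the LR by 10x at epochs 200 and 400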
예제 #22
0
def train(opt):

    ################################
    # Build dataloader
    ################################
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'loader_state_dict': None,
        'vocab': loader.get_vocab(),
    }
    # Load old infos (if any) and check if models are compatible
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert getattr(saved_model_opt, checkme) == getattr(
                    opt, checkme
                ), "Command line argument and saved model disagree on '%s' " % checkme
    infos['opt'] = opt

    #########################
    # Build logger
    #########################
    # naive dict logger
    histories = defaultdict(dict)
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'),
                  'rb') as f:
            histories.update(utils.pickle_load(f))

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    # Load pretrained weights:
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'model.pth')):
        model.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    # Wrap the generation model with the loss function (used for training).
    # This allows the loss to be computed separately on each device.
    lw_model = LossWrapper(model, opt)
    # Wrap with dataparallel
    dp_model = torch.nn.DataParallel(model)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    ##########################
    #  Build optimizer
    ##########################
    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'bert', 'm2transformer'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    #########################
    # Get ready to start
    #########################
    iteration = infos['iter']
    epoch = infos['epoch']
    # For back compatibility
    if 'iterators' in infos:
        infos['loader_state_dict'] = {
            split: {
                'index_list': infos['split_ix'][split],
                'iter_counter': infos['iterators'][split]
            }
            for split in ['train', 'val', 'test']
        }
    loader.load_state_dict(infos['loader_state_dict'])
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    if opt.noamopt:
        optimizer._step = iteration
    # Flag indicating the end of an epoch.
    # Always set to True at the beginning to initialize the lr, etc.
    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    # Start training
    try:
        while True:
            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break

            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                # If start structure loss training
                if opt.structure_after != -1 and epoch >= opt.structure_after:
                    struc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    struc_flag = False

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                    att_masks, data['gts'],
                                    torch.arange(0, len(data['gts'])), sc_flag,
                                    struc_flag)

            loss = model_out['loss'].mean()

            loss.backward()
            if opt.grad_clip_value != 0:
                getattr(torch.nn.utils, 'clip_grad_%s_' %
                        (opt.grad_clip_mode))(model.parameters(),
                                              opt.grad_clip_value)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if struc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start))
            elif not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                tb_summary_writer.add_scalar('train_loss', train_loss,
                                             iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                tb_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                tb_summary_writer.add_scalar('scheduled_sampling_prob',
                                             model.ss_prob, iteration)
                if sc_flag:
                    tb_summary_writer.add_scalar('avg_reward',
                                                 model_out['reward'].mean(),
                                                 iteration)
                elif struc_flag:
                    tb_summary_writer.add_scalar(
                        'lm_loss', model_out['lm_loss'].mean().item(),
                        iteration)
                    tb_summary_writer.add_scalar(
                        'struc_loss', model_out['struc_loss'].mean().item(),
                        iteration)
                    tb_summary_writer.add_scalar(
                        'reward', model_out['reward'].mean().item(), iteration)
                    tb_summary_writer.add_scalar(
                        'reward_var', model_out['reward'].var(1).mean(),
                        iteration)

                histories['loss_history'][
                    iteration] = train_loss if not sc_flag else model_out[
                        'reward'].mean()
                histories['lr_history'][iteration] = opt.current_lr
                histories['ss_prob_history'][iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['loader_state_dict'] = loader.state_dict()

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0 and not opt.save_every_epoch) or \
                (epoch_done and opt.save_every_epoch):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                tb_summary_writer.add_scalar('validation loss', val_loss,
                                             iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        tb_summary_writer.add_scalar(k, v, iteration)
                histories['val_result_history'][iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score

                utils.save_checkpoint(opt, model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    utils.save_checkpoint(
                        opt,
                        model,
                        infos,
                        optimizer,
                        append=str(epoch)
                        if opt.save_every_epoch else str(iteration))

                if best_flag:
                    utils.save_checkpoint(opt,
                                          model,
                                          infos,
                                          optimizer,
                                          append='best')

    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        utils.save_checkpoint(opt, model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
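The snippet above relies on a utils.save_checkpoint helper that is not shown
in this example. As a point of reference only, here is a minimal sketch of
what such a helper could look like, inferred from the call sites above; the
checkpoint_path attribute, the file-name scheme, and the use of pickle are
assumptions, not the project's actual implementation:

import os
import pickle

import torch


def save_checkpoint(opt, model, infos, optimizer, histories=None, append=''):
    """Hypothetical helper matching the calls above; the real one may differ."""
    if not os.path.isdir(opt.checkpoint_path):
        os.makedirs(opt.checkpoint_path)
    # Model and optimizer weights, optionally suffixed (e.g. 'best', iteration).
    torch.save(model.state_dict(),
               os.path.join(opt.checkpoint_path, 'model%s.pth' % append))
    torch.save(optimizer.state_dict(),
               os.path.join(opt.checkpoint_path, 'optimizer%s.pth' % append))
    # Bookkeeping dictionaries (iteration, epoch, best score, loader state, ...).
    with open(os.path.join(opt.checkpoint_path,
                           'infos%s.pkl' % append), 'wb') as f:
        pickle.dump(infos, f)
    if histories is not None:
        with open(os.path.join(opt.checkpoint_path,
                               'histories%s.pkl' % append), 'wb') as f:
            pickle.dump(histories, f)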
Example #23
0
File: policies.py Project: lyltc1/trc-nn
def train_mine_policy(scenario: Scenario,
                      horizon: int,
                      batch_size: int,
                      epochs: int,
                      ntrvs: int,
                      mine_class: nn.Module,
                      mine_params,
                      q_net: nn.Module,
                      pi_net: nn.Module,
                      tradeoff: float,
                      lr: float,
                      tag: str = None,
                      save_every: int = 100,
                      log_video_every: Union[int, None] = None,
                      minibatch_size=0,
                      opt_iters=1,
                      lowest_mi=np.inf,
                      cutoff=np.inf,
                      device=pt.device('cpu')):
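    """Train pi_net and q_net on `scenario` while penalizing the mutual
    information between the state and the task-relevant variables (trvs),
    estimated per time step by MINE networks built from `mine_class`.

    Parameter roles inferred from the body below: `ntrvs` is the dimension of
    the task-relevant variable, `tradeoff` weights the MI penalty (a value of
    -1 appears to disable MI estimation), `cutoff` and `lowest_mi` gate model
    saving, and `opt_iters`/`minibatch_size` control the inner policy-gradient
    updates (a minibatch_size of 0 means "use the full batch").
    """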

    q_net.to(device=device)
    pi_net.to(device=device)
    opt = pt.optim.Adam(list(pi_net.parameters()) + list(q_net.parameters()),
                        lr=lr)
    mine = [mine_class().to(device=device) for t in range(horizon)]
    last_time = time.time()
    mi = pt.zeros(horizon).to(device=device)

    scenario.device = pt.device('cpu')

    prev_best_value = np.inf
    current_value = np.inf

    if minibatch_size == 0:
        minibatch_size = batch_size

    if tag is not None:
        writer = SummaryWriter(f'runs/{tag}', flush_secs=1)

    for epoch in range(epochs):
        #if epoch % save_every == 0 or epoch == epochs - 1:
        start_epoch_event = pt.cuda.Event(enable_timing=True)
        end_epoch_event = pt.cuda.Event(enable_timing=True)
        end_rollout_event = pt.cuda.Event(enable_timing=True)

        start_epoch_event.record()

        pi_log_probs = pt.zeros((horizon, minibatch_size), device=device)
        q_log_probs = pt.zeros((horizon, minibatch_size), device=device)

        q_net.cpu()
        pi_net.cpu()

        states, outputs, samples, trvs, inputs, costs = rollout(
            pi_net, q_net, ntrvs, scenario, horizon, batch_size,
            pt.device('cpu'))
        end_rollout_event.record()
        pt.cuda.synchronize()
        elapsed_rollout_time = start_epoch_event.elapsed_time(
            end_rollout_event) / 1000

        print(f'Rollout Time: {elapsed_rollout_time:.3f}')
        print(
            f'Mean Abs. Displacement: {pt.abs(states[0, -1, :] - states[1, -1, :]).mean().detach().item()}'
        )

        states = states.to(device)
        outputs = outputs.to(device)
        samples = samples.to(device)
        trvs = trvs.to(device)
        inputs = inputs.to(device)
        costs = costs.to(device)

        q_net.to(device)
        pi_net.to(device)

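        # Re-run q_net over the rolled-out trajectories on the training device
        # so the task-relevant variables are produced inside the autograd graph
        # of the parameters being optimized.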
        for s in range(batch_size):
            trv = pt.zeros(ntrvs, device=device)

            for t in range(horizon):
                trvs[:, t, s] = q_net(outputs[:, t, s], trv, t, samples[:, t,
                                                                        s])[0]
                trv = trvs[:, t, s]

        value = costs.sum(axis=0).mean().item()

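        # A tradeoff of -1 appears to act as a sentinel that skips MI
        # estimation; otherwise one MINE network is (re)trained per time step,
        # with extra epochs on the first pass to warm it up.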
        if tradeoff > -1:
            states_mi = states.detach().cuda()
            trvs_mi = trvs.detach().cuda()

            for t in range(horizon):
                mine[t].cuda()
                if epoch == 0:
                    values = train_mine_network(
                        mine[t], (states_mi[:, t, :], trvs_mi[:, t, :]),
                        epochs=100 * mine_params['epochs'])
                else:
                    train_mine_network(mine[t],
                                       (states_mi[:, t, :], trvs_mi[:, t, :]),
                                       epochs=mine_params['epochs'])

            for t in range(horizon):
                num_datapts = states.shape[2]
                batch_size = num_datapts

                joint_batch_idx = np.random.choice(range(num_datapts),
                                                   size=num_datapts,
                                                   replace=False)
                marginal_batch_idx1 = np.random.choice(range(num_datapts),
                                                       size=num_datapts,
                                                       replace=False)
                marginal_batch_idx2 = np.random.choice(range(num_datapts),
                                                       size=num_datapts,
                                                       replace=False)

                joint_batch = pt.cat(
                    (states[:, t, joint_batch_idx], trvs[:, t,
                                                         joint_batch_idx]),
                    axis=0).t()
                marginal_batch = pt.cat((states[:, t, marginal_batch_idx1],
                                         trvs[:, t, marginal_batch_idx2]),
                                        axis=0).t()

                j_T = mine[t](joint_batch)
                m_T = mine[t](marginal_batch)

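                # MINE estimate via the Donsker-Varadhan lower bound:
                # I(x_t; trv_t) >= E_joint[T] - log E_marginal[exp(T)]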
                mi[t] = j_T.mean() - pt.log(pt.mean(pt.exp(m_T)))

        mi_sum = mi.sum()
        baseline = costs.sum(axis=0).mean()

        current_value = value + tradeoff * mi_sum.detach()

        if value < cutoff and mi_sum < lowest_mi:
            print('Saving Model...')
            lowest_mi = mi_sum.item()
            pt.save(
                {
                    'pi_net_state_dict': pi_net.state_dict(),
                    'q_net_state_dict': q_net.state_dict()
                }, f'models/{tag}_epoch_{epoch}_mi_{lowest_mi:.3f}')
        else:
            print(f'Current Best: {prev_best_value}')

        for opt_iter in range(opt_iters):
            print(f'Computing Iteration {opt_iter}')
            minibatch_idx = np.random.choice(range(batch_size),
                                             size=minibatch_size,
                                             replace=False)

            outputs_minibatch = outputs[:, :, minibatch_idx]
            trvs_minibatch = trvs[:, :, minibatch_idx]
            inputs_minibatch = inputs[:, :, minibatch_idx]
            costs_minibatch = costs[:, minibatch_idx]

            for s in range(minibatch_size):
                trv = pt.zeros(ntrvs, device=device)

                for t in range(horizon):
                    # Log-probability under q_net of the trv generated for
                    # this minibatch sample.
                    q_log_probs[t, s] = q_net.log_prob(
                        trvs_minibatch[:, t, s].detach(),
                        outputs_minibatch[:, t, s], trv.detach(), t)
                    pi_log_probs[t, s] = pi_net.log_prob(
                        inputs_minibatch[:, t, s].detach(),
                        trvs_minibatch[:, t, s].detach(), t)
                    trv = trvs_minibatch[:, t, s]

            opt.zero_grad()
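            # Score-function (REINFORCE-style) surrogate: log-probabilities of
            # both networks weighted by the cost advantage (trajectory cost
            # minus the mean-cost baseline), plus the weighted MI estimate.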
            loss = pt.mul(pi_log_probs.sum(axis=0), costs_minibatch.sum(axis=0) - baseline).mean() + \
                   pt.mul(q_log_probs.sum(axis=0), costs_minibatch.sum(axis=0) - baseline).mean() + \
                   tradeoff * mi_sum
            loss.backward()
            opt.step()

            pi_log_probs = pi_log_probs.detach()
            q_log_probs = q_log_probs.detach()

        if tag is not None:
            writer.add_scalar('Loss/Total', value + tradeoff * mi.sum().item(),
                              epoch)
            writer.add_scalar('Loss/MI', mi_sum, epoch)
            writer.add_scalar('Loss/Cost', value, epoch)
            writer.add_histogram('Loss/Cost Dist', costs.sum(axis=0), epoch)

            if log_video_every is not None and epoch % log_video_every == 0:
                print('Saving Video...')

                best_traj_idx = pt.argmin(costs.sum(axis=0))
                worst_traj_idx = pt.argmax(costs.sum(axis=0))

                best_traj_vid = pt.stack([
                    pt.stack([
                        outputs[:, t, best_traj_idx].view(3, 64, 64)
                        for t in range(horizon)
                    ])
                ])
                worst_traj_vid = pt.stack([
                    pt.stack([
                        outputs[:, t, worst_traj_idx].view(3, 64, 64)
                        for t in range(horizon)
                    ])
                ])

                writer.add_video('Loss/Worst Traj', worst_traj_vid, epoch)
                writer.add_video('Loss/Best Traj', best_traj_vid, epoch)

        mi = mi.detach()
        end_epoch_event.record()
        pt.cuda.synchronize()
        elapsed_epoch_time = start_epoch_event.elapsed_time(
            end_epoch_event) / 1000

        print(
            f'[{tradeoff}.{epoch}: {elapsed_epoch_time:.3f}]\t\tAvg. Cost: {value:.3f}\t\tEst. MI: {mi_sum.item():.5f}\t\tTotal: {value + tradeoff * mi_sum.item():.3f}\t\t Lowest MI: {lowest_mi:.3f}'
        )

        if epoch == epochs - 1:
            return lowest_mi
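The example above also depends on a train_mine_network helper and a
user-supplied mine_class, neither of which is shown. A minimal sketch of what
such a pair could look like under the standard MINE formulation; the
architecture, layer sizes, and the plain Adam loop are assumptions, not the
project's implementation:

import numpy as np
import torch as pt
import torch.nn as nn


class MineNet(nn.Module):
    """Hypothetical statistics network T(x, z) -> scalar."""

    def __init__(self, in_dim, hidden=64):
        # in_dim must equal state_dim + ntrvs for the slices used above.
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, hidden), nn.ELU(),
                                 nn.Linear(hidden, hidden), nn.ELU(),
                                 nn.Linear(hidden, 1))

    def forward(self, xz):
        return self.net(xz)


def train_mine_network(mine, data, epochs=100, lr=1e-3):
    # data = (x, z), each of shape (dim, num_samples), matching the
    # states[:, t, :] and trvs[:, t, :] slices passed in above.
    x, z = data
    n = x.shape[1]
    opt = pt.optim.Adam(mine.parameters(), lr=lr)
    for _ in range(epochs):
        joint_idx = np.random.permutation(n)
        marginal_idx = np.random.permutation(n)
        joint = pt.cat((x[:, joint_idx], z[:, joint_idx]), axis=0).t()
        marginal = pt.cat((x[:, joint_idx], z[:, marginal_idx]), axis=0).t()
        # Maximize the Donsker-Varadhan bound, i.e. minimize its negation.
        loss = -(mine(joint).mean() - pt.log(pt.mean(pt.exp(mine(marginal)))))
        opt.zero_grad()
        loss.backward()
        opt.step()

Since mine_class() is instantiated with no arguments in the example, in_dim
would have to be bound beforehand, e.g. with
functools.partial(MineNet, in_dim=state_dim + ntrvs).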