def __init__(self, hparams, model=None, vocab_word=None, vocab_role=None,
                 vocab_pos=None, checkpoint=None, summary_writer=None):
        super(Predictor, self).__init__()
        self.hparams = hparams
        self.model = model
        self.vocab_word = vocab_word
        self.vocab_role = vocab_role
        self.vocab_pos = vocab_pos
        self.device = hparams.device
        self.batch_size = hparams.batch_size

        # Beam-search configuration
        self.min_length = hparams.min_length
        self.gen_max_length = hparams.gen_max_length
        self.beam_size = hparams.beam_size
        self.start_token_id = self.vocab_word.token2id['<BEGIN>']
        self.end_token_id = self.vocab_word.token2id['<END>']

        self.summary_writer = summary_writer

        if model is None and checkpoint != '':
            self.build_model()

            if self.vocab_word is None:
                self.vocab_word = load_vocab(self.hparams.vocab_word_path)

            model_state_dict, optimizer_state_dict = load_checkpoint(self.hparams.load_pthpath)

            print('============= Loading Trained Model from: ', self.hparams.load_pthpath, ' ==================')
            if isinstance(self.model, nn.DataParallel):
                self.model.module.load_state_dict(model_state_dict)
            else:
                self.model.load_state_dict(model_state_dict, strict=True)
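
Several of these snippets call a load_checkpoint helper that is never shown. A minimal sketch of what it plausibly does, assuming the checkpoint was written with torch.save as a dict holding both state dicts (the 'model' and 'optimizer' keys are an assumption, not confirmed by the source):

import torch

def load_checkpoint(checkpoint_pthpath):
    # Hypothetical sketch: read a dict saved via
    # torch.save({'model': ..., 'optimizer': ...}, path)
    # and hand back the two state dicts.
    components = torch.load(checkpoint_pthpath, map_location='cpu')
    return components['model'], components['optimizer']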
Example No. 2
    def setup_training(self):
        self.save_dirpath = self.hparams.save_dirpath
        today = '{}M_{}D'.format(datetime.today().month,
                                 datetime.today().day)
        tensorboard_path = self.save_dirpath + today
        self.summary_writer = SummaryWriter(tensorboard_path, comment="Unmt")
        self.checkpoint_manager = CheckpointManager(self.model,
                                                    self.optimizer,
                                                    self.save_dirpath,
                                                    hparams=self.hparams)

        # If loading from checkpoint, adjust start epoch and load parameters.
        if self.hparams.load_pthpath == "":
            self.start_epoch = 1
        else:
            # "path/to/checkpoint_xx.pth" -> xx
            self.start_epoch = int(
                self.hparams.load_pthpath.split("_")[-1][:-4])
            self.start_epoch += 1
            model_state_dict, optimizer_state_dict = load_checkpoint(
                self.hparams.load_pthpath)
            if isinstance(self.model, nn.DataParallel):
                self.model.module.load_state_dict(model_state_dict,
                                                  strict=True)
            else:
                self.model.load_state_dict(model_state_dict)

            self.optimizer.load_state_dict(optimizer_state_dict)
            self.previous_model_path = self.hparams.load_pthpath
            print("Loaded model from {}".format(self.hparams.load_pthpath))

        print("""
            # -------------------------------------------------------------------------
            #   Setup Training Finished
            # -------------------------------------------------------------------------
            """)
Example No. 3
def train():

    ###################################################
    # Argparse setup (click was a bad idea after all) #
    ###################################################

    parser = argparse.ArgumentParser()

    parser.add_argument('--config-json',
                        type=str,
                        help="The json file specifying the args below")

    parser.add_argument('--embedder-path',
                        type=str,
                        help="Path to the embedder checkpoint." +
                        " Example: 'embedder/data/best_model'")

    group_chk = parser.add_argument_group('checkpointing')
    group_chk.add_argument('--epoch-save-interval',
                           type=int,
                           help="After every [x] epochs, save with " +
                           "the checkpoint manager")
    group_chk.add_argument('--save-dir',
                           type=str,
                           help="Relative path of save directory, " +
                           "include the trailing /")
    group_chk.add_argument("--load-dir",
                           type=str,
                           help="Checkpoint prefix directory to " +
                           "load initial model from")

    group_system = parser.add_argument_group('system')
    group_system.add_argument('--cpu-workers',
                              type=int,
                              help="Number of CPU workers for dataloader")
    group_system.add_argument('--torch-seed',
                              type=int,
                              help="Seed for for torch and torch_cudnn")
    group_system.add_argument('--gpu-ids',
                              help="The GPU ID to use. If -1, use CPU")

    group_data = parser.add_argument_group('data')
    group_data.add_argument('--mel-size',
                            type=int,
                            help="Number of channels in the mel-gram")
    group_data.add_argument('--style-size',
                            type=int,
                            help="Dimensionality of style vector")
    group_data.add_argument('--dset-num-people',
                            type=int,
                            help="If using VCTK, an integer under 150")
    group_data.add_argument('--dset-num-samples',
                            type=int,
                            help="If using VCTK, an integer under 300")
    group_data.add_argument('--mel-root',
                            default='data/taco/',
                            type=str,
                            help='Path to the directory (include last /) ' +
                            'where the person mel folders are')

    group_training = parser.add_argument_group('training')
    group_training.add_argument('--num-epochs',
                                type=int,
                                help="The number of epochs to train for")
    group_training.add_argument('--lr-dtor-isvoice',
                                type=float,
                                help="Learning rate for the is-voice " +
                                "discriminator")
    group_training.add_argument('--lr-tform',
                                type=float,
                                help="Learning rate for the transformer")

    group_training.add_argument('--num-batches-dtor-isvoice',
                                type=int,
                                help="Batches per epoch for the is-voice " +
                                "discriminator")
    group_training.add_argument('--batch-size-dtor-isvoice',
                                type=int,
                                help="Batch size for the is-voice " +
                                "discriminator")

    group_training.add_argument('--num-batches-tform',
                                type=int,
                                help="Batches per epoch for the transformer")
    group_training.add_argument('--batch-size-tform',
                                type=int,
                                help="Batch size for the transformer")

    group_model = parser.add_argument_group('model')
    group_model.add_argument('--identity-mode', help='One of [norm, cos, nn]')

    args = parser.parse_args()
    if args.config_json is not None:
        with open(args.config_json) as json_file:
            file_args = json.load(json_file)
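        # CLI values, when given, override the JSON file's values.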
        cli_dict = vars(args)
        for key in cli_dict:
            if cli_dict[key] is not None:
                file_args[key] = cli_dict[key]
        args.__dict__ = file_args

    print("CLI args are: ", args)
    with open("configs/basic.yml") as f:
        config = yaml.full_load(f)

    ############################
    # Setting up the constants #
    ############################

    if args.save_dir is not None and args.save_dir[-1] != "/":
        args.save_dir += "/"
    if args.load_dir is not None and args.load_dir[-1] != "/":
        args.load_dir += "/"

    SAVE_DTOR_ISVOICE = args.save_dir + FOLDER_DTOR_IV
    SAVE_TRANSFORMER = args.save_dir + FOLDER_TRANSFORMER

    ############################
    # Reproducibility Settings #
    ############################
    # Refer to https://pytorch.org/docs/stable/notes/randomness.html
    torch.manual_seed(args.torch_seed)
    torch.cuda.manual_seed_all(args.torch_seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # TODO Enable?
    # torch.set_default_tensor_type(torch.cuda.FloatTensor)
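    # Assumption: if Python's `random` or NumPy RNGs are used anywhere in the
    # pipeline, they would need seeding as well, e.g.
    #   random.seed(args.torch_seed); np.random.seed(args.torch_seed)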

    #############################
    # Setting up Pytorch device #
    #############################
    use_cpu = args.gpu_ids == -1
    device = torch.device("cpu" if use_cpu else "cuda")

    ###############################################
    # Initialize the model and related optimizers #
    ###############################################

    if args.load_dir is None:
        start_epoch = 0
    else:
        # "path/to/checkpoint_xx/" -> xx
        start_epoch = int(args.load_dir.rstrip("/").split("_")[-1])

    model = ProjectModel(config=config["transformer"],
                         embedder_path=args.embedder_path,
                         mel_size=args.mel_size,
                         style_size=args.style_size,
                         identity_mode=args.identity_mode,
                         cuda=(not use_cpu))
    model = model.to(device)
    tform_optimizer = torch.optim.Adam(model.transformer.parameters(),
                                       lr=args.lr_tform)
    tform_checkpointer = CheckpointManager(model.transformer, tform_optimizer,
                                           SAVE_TRANSFORMER,
                                           args.epoch_save_interval,
                                           start_epoch + 1)

    dtor_isvoice_optimizer = torch.optim.Adam(model.isvoice_dtor.parameters(),
                                              lr=args.lr_dtor_isvoice)
    dtor_isvoice_checkpointer = CheckpointManager(model.isvoice_dtor,
                                                  dtor_isvoice_optimizer,
                                                  SAVE_DTOR_ISVOICE,
                                                  args.epoch_save_interval,
                                                  start_epoch + 1)

    ###############################################

    # Load the checkpoint, if it is specified
    if args.load_dir is not None:
        tform_md, tform_od = load_checkpoint(SAVE_TRANSFORMER)
        model.transformer.load_state_dict(tform_md)
        tform_optimizer.load_state_dict(tform_od)

        dtor_isvoice_md, dtor_isvoice_od = load_checkpoint(SAVE_DTOR_ISVOICE)
        model.isvoice_dtor.load_state_dict(dtor_isvoice_md)
        dtor_isvoice_optimizer.load_state_dict(dtor_isvoice_od)

    ##########################
    # Declaring the datasets #
    ##########################

    dset_wrapper = VCTK_Wrapper(
        model.embedder,
        args.dset_num_people,
        args.dset_num_samples,
        args.mel_root,
        device,
    )

    if args.mel_size != dset_wrapper.mel_from_ids(0, 0).size()[-1]:
        raise RuntimeError("--mel-size does not match the mels on disk")

    dset_isvoice_real = Isvoice_Dataset_Real(dset_wrapper)
    dset_isvoice_fake = Isvoice_Dataset_Fake(dset_wrapper, model.embedder,
                                             model.transformer)
    dset_generator_train = Generator_Dataset(dset_wrapper)
    # We're enforcing identity via a resnet connection for now, so unused
    # dset_identity_real = Identity_Dataset_Real(dset_wrapper,
    #                                            embedder)
    # dset_identity_fake = Identity_Dataset_Fake(dset_wrapper,
    #                                            embedder, transformer)

    def collate_along_timeaxis(batch):
        return collate_pad_tensors(batch, pad_dim=1)
    dload_isvoice_real = DataLoader(dset_isvoice_real,
                                    batch_size=args.batch_size_dtor_isvoice,
                                    collate_fn=collate_along_timeaxis)
    dload_isvoice_fake = DataLoader(dset_isvoice_fake,
                                    batch_size=args.batch_size_dtor_isvoice,
                                    collate_fn=collate_along_timeaxis)
    dload_generator = DataLoader(dset_generator_train,
                                 batch_size=args.batch_size_tform,
                                 collate_fn=Generator_Dataset.collate_fn)

    ############################
    # The actual training loop #
    ############################
    train_start_time = datetime.now()
    print("Started Training at {}".format(train_start_time))
    for epoch in range(args.num_epochs):
        epoch_start_time = datetime.now()
        ###############
        # (D1) Train Real vs Fake Discriminator
        ###############
        train_dtor(model.isvoice_dtor, dtor_isvoice_optimizer,
                   dload_isvoice_real, dload_isvoice_fake,
                   args.num_batches_dtor_isvoice, device)
        dtor_isvoice_checkpointer.step()
        gc.collect()

        # Train generators here
        ################
        # (G) Update Generator
        ################
        val_loss = train_gen(model,
                             tform_optimizer,
                             dload_generator,
                             device,
                             num_batches=args.num_batches_tform)
        tform_checkpointer.step()
        gc.collect()
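
collate_pad_tensors, reached through collate_along_timeaxis above, is not shown in the source. A plausible sketch, assuming each dataset item is a (channels, time) mel tensor and that batching zero-pads along the requested dim; the behavior is an assumption:

import torch

def collate_pad_tensors(batch, pad_dim=1):
    # Zero-pad every tensor to the longest length along pad_dim, then stack.
    max_len = max(t.size(pad_dim) for t in batch)
    padded = []
    for t in batch:
        pad_amount = max_len - t.size(pad_dim)
        if pad_amount > 0:
            pad_shape = list(t.size())
            pad_shape[pad_dim] = pad_amount
            t = torch.cat([t, t.new_zeros(pad_shape)], dim=pad_dim)
        padded.append(t)
    return torch.stack(padded)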
Example No. 4
if __name__ == '__main__':
    args = parser.parse_args()

    with open(args.config_json) as json_file:
        file_args = json.load(json_file)

    with open("configs/basic.yml") as yaml_file:
        config = yaml.full_load(yaml_file)

    os.makedirs(args.output_path, exist_ok=True)

    device = torch.device("cuda") if args.use_gpu else torch.device("cpu")

    model = ProjectModel(config=config["transformer"],
                         embedder_path=file_args["embedder_path"],
                         mel_size=file_args["mel_size"],
                         style_size=file_args["style_size"],
                         identity_mode=file_args["identity_mode"],
                         cuda=args.use_gpu)

    model = model.to(device)
    tform_md, tform_od = load_checkpoint(args.transformer_path)
    model.transformer.load_state_dict(tform_md)

    mel = generate(model, args,
                   device).to(torch.device("cpu")).squeeze().permute(1, 0)
    torch.save(
        mel,
        Path(args.output_path,
             args.content_mel.stem + "_" + args.style_mel.stem + ".pt"))
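
Reading the generated mel back, e.g. for vocoding or inspection, is then a one-liner; the file name and the (time, mel_size) layout implied by the squeeze/permute above are assumptions:

import torch

mel = torch.load('outputs/content_style.pt')  # hypothetical output file
print(mel.shape)  # expected (time, mel_size) after the permute above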
Example No. 5
    print("{:<20}: {}".format(arg, getattr(args, arg)))

# ================================================================================================
#   SETUP DATASET, DATALOADER, MODEL
# ================================================================================================

dataset = SoundDataset(config["dataset"]["source_dir"])
dataloader = DataLoader(dataset,
                        batch_size=config["solver"]["batch_size"],
                        num_workers=args.cpu_workers)

model = None  # NOTE: actual model construction is omitted in this snippet
if -1 not in args.gpu_ids:
    model = nn.DataParallel(model, args.gpu_ids)

model_state_dict, _ = load_checkpoint(args.load_pthpath)
if isinstance(model, nn.DataParallel):
    model.module.load_state_dict(model_state_dict)
else:
    model.load_state_dict(model_state_dict)
print("Loaded model from {}".format(args.load_pthpath))

# ================================================================================================
#   EVALUATION LOOP
# ================================================================================================

# Note that since our evaluation is qualitative, we may consider just generating audio
model.eval()

for i, batch in enumerate(tqdm(dataloader)):
    for key in batch:
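        # (Snippet truncated here.) A typical body would move each tensor to
        # the device, e.g. batch[key] = batch[key].to(device); an assumption.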
Example No. 6
optimizer = optim.Adam(model_conv.parameters(), lr=args.lr)
exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                  mode='min',
                                                  patience=5,
                                                  factor=0.1)

es = EarlyStopping(patience=args.es_patience)

if args.debug:
    n_steps = 50
else:
    n_steps = args.epochs * (len(train_data) // batch_size)

if has_checkpoint():
    state = load_checkpoint()
    model_conv.load_state_dict(state['model_dict'])
    optimizer.load_state_dict(state['optimizer_dict'])
    exp_lr_scheduler.load_state_dict(state['scheduler_dict'])
    train_loader.sampler.load_state_dict(state['sampler_dict'])
    start_step = state['start_step']
    es = state['es']
    torch.random.set_rng_state(state['rng'])
    print("Loaded checkpoint at step %s" % start_step)
else:
    start_step = 0

logger = setup_logs(args.output_dir, run_name)  # setup logs
writer = SummaryWriter(log_dir=os.path.join(args.output_dir, 'tensorboard'))
tr_losses, tr_accs = [], []
for step in range(start_step, n_steps):
    pass  # (loop body truncated in the source)
Example No. 7
def fit(config,
        logger,
        model,
        dataloaders,
        losses,
        optimizer,
        callbacks,
        lr_scheduler=None,
        is_inception=False,
        resume_from='L'):

    since = time.time()

    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []

    # state_dict() returns references to the live tensors; deepcopy gives a
    # frozen snapshot.
    best_state = copy.deepcopy(model.state_dict())

    # Load Checkpoint?
    start_epoch, best_acc, model, optimizer, reps = load_checkpoint(
        config=config,
        resume_from=resume_from,
        model=model,
        optimizer=optimizer)

    if hasattr(losses['train'], 'reps') and reps is not None:
        losses['train'].set_reps(reps)

    for callback in callbacks['training_start']:
        callback(0,
                 0,
                 0,
                 model,
                 dataloaders,
                 losses,
                 optimizer,
                 data={},
                 stats={})

    step = start_epoch * len(dataloaders['train'])
    for epoch in range(start_epoch, config.train.epochs):
        print('Epoch {}/{}'.format(epoch, config.train.epochs - 1))
        logger.info('Epoch {}/{}'.format(epoch, config.train.epochs - 1))
        print('-' * 10)
        logger.info('-' * 10)

        for callback in callbacks['epoch_start']:
            callback(epoch,
                     0,
                     step,
                     model,
                     dataloaders,
                     losses,
                     optimizer,
                     data={},
                     stats={})

        # Iterate over data.
        model.train()
        batch = 0
        # this gets a batch (or an episode)
        for inputs, labels in dataloaders['train']:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            model.zero_grad()

            # forward
            # Get model outputs and calculate loss
            if is_inception:
                # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                outputs, aux_outputs = model(inputs)
                loss_main, sample_losses_main, pred, acc = \
                    losses['train'](input=outputs, target=labels)
                loss_aux, sample_losses_aux, pred_aux, acc_aux = \
                    losses['train'](input=aux_outputs, target=labels)
                loss = loss_main + 0.4 * loss_aux
                sample_losses = sample_losses_main + 0.4 * sample_losses_aux
            else:
                outputs = model(inputs)
                loss, sample_losses, pred, acc = losses['train'](input=outputs,
                                                                 target=labels)

            # backward + optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # statistics
            train_loss.append(loss.item())
            train_acc.append(acc.item())

            for callback in callbacks['batch_end']:
                callback(epoch,
                         batch,
                         step,
                         model,
                         dataloaders,
                         losses,
                         optimizer,
                         data={
                             'inputs': inputs,
                             'outputs': outputs,
                             'labels': labels
                         },
                         stats={
                             'Training_Loss': train_loss[-1],
                             'Training_Acc': train_acc[-1],
                             'sample_losses': sample_losses
                         })

            batch += 1
            step += 1

        avg_loss = np.mean(train_loss[-batch:])
        avg_acc = np.mean(train_acc[-batch:])

        print('Avg Training Loss: {:.4f} Acc: {:.4f}'.format(
            avg_loss, avg_acc))
        logger.info('Avg Training Loss: {:.4f} Acc: {:.4f}'.format(
            avg_loss, avg_acc))
        if lr_scheduler:
            lr_scheduler.step()

        # Validation?
        if config.val.every > 0 and (epoch + 1) % config.val.every == 0:

            for callback in callbacks['validation_start']:
                callback(epoch,
                         0,
                         step,
                         model,
                         dataloaders,
                         losses,
                         optimizer,
                         data={},
                         stats={})

            model.eval()
            v_batch = 0
            val_loss = []
            val_acc = []
            for v_inputs, v_labels in dataloaders['val']:
                v_inputs = v_inputs.to(device)
                v_labels = v_labels.to(device)

                # Gradients aren't needed here, which saves memory.
                with torch.no_grad():
                    # Get model outputs and calculate loss
                    v_outputs = model(v_inputs)
                loss, sample_losses, pred, acc = losses['val'](input=v_outputs,
                                                               target=v_labels)

                # statistics
                val_loss.append(loss.item())
                val_acc.append(acc.item())

                for callback in callbacks['validation_batch_end']:
                    callback(
                        epoch,
                        batch,
                        step,
                        model,
                        dataloaders,
                        losses,
                        optimizer,  # todo should we make this v_batch?
                        data={
                            'inputs': v_inputs,
                            'outputs': v_outputs,
                            'labels': v_labels
                        },
                        stats={
                            'Validation_Loss': val_loss[-1],
                            'Validation_Acc': val_acc[-1]
                        })

                v_batch += 1

            avg_v_loss = np.mean(val_loss)
            avg_v_acc = np.mean(val_acc)

            print('Avg Validation Loss: {:.4f} Acc: {:.4f}'.format(
                avg_v_loss, avg_v_acc))
            logger.info('Avg Validation Loss: {:.4f} Acc: {:.4f}'.format(
                avg_v_loss, avg_v_acc))

            # Best validation accuracy yet?
            if avg_v_acc > best_acc:
                best_acc = avg_v_acc
                best_state = copy.deepcopy(model.state_dict())
                if hasattr(losses['train'], 'reps'):
                    reps = losses['train'].get_reps()
                else:
                    reps = None
                save_checkpoint(config,
                                epoch,
                                model,
                                optimizer,
                                best_acc,
                                reps=reps,
                                is_best=True)

            # End of validation callbacks
            for callback in callbacks['validation_end']:
                callback(epoch,
                         batch,
                         step,
                         model,
                         dataloaders,
                         losses,
                         optimizer,
                         data={
                             'inputs': v_inputs,
                             'outputs': v_outputs,
                             'labels': v_labels
                         },
                         stats={
                             'Avg_Validation_Loss': avg_v_loss,
                             'Avg_Validation_Acc': avg_v_acc
                         })

        # End of epoch callbacks
        for callback in callbacks['epoch_end']:
            callback(epoch,
                     batch,
                     step,
                     model,
                     dataloaders,
                     losses,
                     optimizer,
                     data={
                         'inputs': inputs,
                         'outputs': outputs,
                         'labels': labels
                     },
                     stats={
                         'Avg_Training_Loss': avg_loss,
                         'Avg_Training_Acc': avg_acc
                     })

        # Checkpoint?
        if config.train.checkpoint_every > 0 and epoch % config.train.checkpoint_every == 0:
            if hasattr(losses['train'], 'reps'):
                reps = losses['train'].get_reps()
            else:
                reps = None
            save_checkpoint(config,
                            epoch,
                            model,
                            optimizer,
                            best_acc,
                            reps=reps,
                            is_best=False)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, time_elapsed % 60))
    logger.info('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    logger.info('Best val Acc: {:4f}'.format(best_acc))

    for callback in callbacks['training_end']:
        callback(epoch,
                 batch,
                 step,
                 model,
                 dataloaders,
                 losses,
                 optimizer,
                 data={},
                 stats={})

    # If there was no validation, save the last model as best
    if config.val.every < 1:
        best_acc = avg_acc
        best_state = copy.deepcopy(model.state_dict())
        if hasattr(losses['train'], 'reps'):
            reps = losses['train'].get_reps()
        else:
            reps = None
        save_checkpoint(config,
                        epoch,
                        model,
                        optimizer,
                        best_acc,
                        reps=reps,
                        is_best=True)

    return model, best_state, best_acc, train_loss, train_acc, val_loss, val_acc
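
This fit function (like the variant in the next example) snapshots weights with copy.deepcopy(model.state_dict()). The deepcopy matters: state_dict() returns references to the live parameter tensors, so a snapshot taken without it silently tracks later updates. A minimal demonstration:

import copy

import torch
import torch.nn as nn

layer = nn.Linear(2, 2)
alias = layer.state_dict()                      # shares storage with the model
snapshot = copy.deepcopy(layer.state_dict())    # true frozen copy

with torch.no_grad():
    layer.weight.add_(1.0)                      # simulate a training update

print(torch.equal(alias['weight'], layer.weight))     # True: alias moved too
print(torch.equal(snapshot['weight'], layer.weight))  # False: snapshot froze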
Example No. 8
def fit(config,
        logger,
        model,
        dataloaders,
        losses,
        optimizer,
        callbacks,
        lr_scheduler=None,
        is_inception=False,
        resume_from='L'):

    # Initialize the tensor holders here.
    im_data = torch.FloatTensor(1)
    im_info = torch.FloatTensor(1)
    num_boxes = torch.LongTensor(1)
    gt_boxes = torch.FloatTensor(1)

    # ship to cuda
    im_data = im_data.cuda()
    im_info = im_info.cuda()
    num_boxes = num_boxes.cuda()
    gt_boxes = gt_boxes.cuda()

    # Wrap in Variable (a no-op since PyTorch 0.4; kept for older code).
    from torch.autograd import Variable
    im_data = Variable(im_data)
    im_info = Variable(im_info)
    num_boxes = Variable(num_boxes)
    gt_boxes = Variable(gt_boxes)

    since = time.time()

    train_loss = []
    train_rpn_loss_cls = []
    train_rpn_loss_box = []
    train_rcnn_loss_cls = []
    train_rcnn_loss_bbox = []
    train_rpn_acc = []
    train_rcnn_acc = []
    val_loss = []
    val_rpn_acc = []
    val_rcnn_acc = []

    # state_dict() returns references to the live tensors; deepcopy gives a
    # frozen snapshot.
    best_state = copy.deepcopy(model.state_dict())

    # Load Checkpoint?
    start_epoch, best_acc, model, optimizer, reps = load_checkpoint(
        config=config,
        resume_from=resume_from,
        model=model,
        optimizer=optimizer)

    if hasattr(losses['train'], 'reps') and reps is not None:
        losses['train'].set_reps(reps)

    for callback in callbacks['training_start']:
        callback(0,
                 0,
                 0,
                 model,
                 dataloaders,
                 losses,
                 optimizer,
                 data={},
                 stats={})

    step = start_epoch * len(dataloaders['train'])
    for epoch in range(start_epoch, config.train.epochs):
        print('Epoch {}/{}'.format(epoch, config.train.epochs - 1))
        logger.info('Epoch {}/{}'.format(epoch, config.train.epochs - 1))
        print('-' * 10)
        logger.info('-' * 10)

        for callback in callbacks['epoch_start']:
            callback(epoch,
                     0,
                     step,
                     model,
                     dataloaders,
                     losses,
                     optimizer,
                     data={},
                     stats={})

        # Iterate over data.
        model.train()
        batch = 0
        print('Doing %d batches...' % len(dataloaders['train']))
        for data in dataloaders['train']:  # this gets a batch (or an episode)
            # inputs = inputs.to(device)
            # labels = labels.to(device)

            im_data.data.resize_(data[0].size()).copy_(data[0])
            im_info.data.resize_(data[1].size()).copy_(data[1])
            gt_boxes.data.resize_(data[2].size()).copy_(data[2])
            num_boxes.data.resize_(data[3].size()).copy_(data[3])

            # zero the parameter gradients
            model.zero_grad()

            # forward
            # Get model outputs and calculate loss

            # outputs = model(inputs)
            gt_rois, rois, rois_label, cls_pred, bbox_pred, rpn_scores, rpn_bboxs, rpn_cls_scores, rpn_bbox_preds, anchors =\
                model(im_data, im_info, gt_boxes, num_boxes)

            outputs = (gt_rois, rois, rois_label, cls_pred, bbox_pred,
                       rpn_scores, rpn_bboxs, rpn_cls_scores, rpn_bbox_preds,
                       anchors)
            loss, rpn_loss_cls, rpn_loss_box, rcnn_loss_cls, rcnn_loss_bbox, rpn_acc, rcnn_acc = \
                losses['train'](input=outputs, target=[gt_boxes, num_boxes, im_info])

            # backward + optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # statistics
            train_loss.append(loss.item())
            train_rpn_loss_cls.append(rpn_loss_cls.mean().item())
            train_rpn_loss_box.append(rpn_loss_box.mean().item())
            train_rcnn_loss_cls.append(rcnn_loss_cls.mean().item())
            train_rcnn_loss_bbox.append(rcnn_loss_bbox.mean().item())
            train_rpn_acc.append(rpn_acc.item())
            train_rcnn_acc.append(rcnn_acc.item())

            for callback in callbacks['batch_end']:
                callback(epoch,
                         batch,
                         step,
                         model,
                         dataloaders,
                         losses,
                         optimizer,
                         data={
                             'inputs': data,
                             'outputs': None,
                             'labels': None
                         },
                         stats={
                             'Training_Loss': train_loss[-1],
                             'Training_RPN_Acc': train_rpn_acc[-1],
                             'Training_RCNN_Acc': train_rcnn_acc[-1],
                             'Training_RPN_Class_Loss': train_rpn_loss_cls[-1],
                             'Training_RPN_Box_Loss': train_rpn_loss_box[-1],
                             'Training_RCNN_Class_Loss':
                             train_rcnn_loss_cls[-1],
                             'Training_RCNN_Box_Loss': train_rcnn_loss_bbox[-1]
                         })

            batch += 1
            step += 1

        avg_loss = np.mean(train_loss[-batch:])
        avg_rpn_acc = np.mean(train_rpn_acc[-batch:])
        avg_rcnn_acc = np.mean(train_rcnn_acc[-batch:])

        print(
            'Avg Train: Total Loss {:.4f}, RPN Class Loss {:.4f}, RPN Box Loss {:.4f}, RPN Acc: {:.4f}, RCNN Class Loss {:.4f}, RCNN Box Loss {:.4f}, RCNN Acc: {:.4f}'
            .format(avg_loss, np.mean(train_rpn_loss_cls),
                    np.mean(train_rpn_loss_box), avg_rpn_acc,
                    np.mean(train_rcnn_loss_cls),
                    np.mean(train_rcnn_loss_bbox), avg_rcnn_acc))
        logger.info(
            'Avg Train: Total Loss {:.4f}, RPN Class Loss {:.4f}, RPN Box Loss {:.4f}, RPN Acc: {:.4f}, RCNN Class Loss {:.4f}, RCNN Box Loss {:.4f}, RCNN Acc: {:.4f}'
            .format(avg_loss, np.mean(train_rpn_loss_cls),
                    np.mean(train_rpn_loss_box), avg_rpn_acc,
                    np.mean(train_rcnn_loss_cls),
                    np.mean(train_rcnn_loss_bbox), avg_rcnn_acc))

        if lr_scheduler:
            lr_scheduler.step()

        # Validation?
        if config.val.every > 0 and (epoch + 1) % config.val.every == 0:

            for callback in callbacks['validation_start']:
                callback(epoch,
                         0,
                         step,
                         model,
                         dataloaders,
                         losses,
                         optimizer,
                         data={},
                         stats={})

            # model.eval()
            v_batch = 0
            val_loss = []
            val_acc = []
            val_rpn_loss_cls = []
            val_rpn_loss_box = []
            val_rcnn_loss_cls = []
            val_rcnn_loss_bbox = []
            print("Validation with %d batches" % len(dataloaders['val']))
            for data in dataloaders['val']:
                # v_inputs = v_inputs.to(device)
                # v_labels = v_labels.to(device)

                im_data.data.resize_(data[0].size()).copy_(data[0])
                im_info.data.resize_(data[1].size()).copy_(data[1])
                gt_boxes.data.resize_(data[2].size()).copy_(data[2])
                num_boxes.data.resize_(data[3].size()).copy_(data[3])

                # print(gt_boxes)

                # Gradients aren't needed here, which saves memory.
                with torch.no_grad():
                    # Get model outputs and calculate loss

                    gt_rois, rois, rois_label, cls_pred, bbox_pred, rpn_scores, rpn_bboxs, rpn_cls_scores, rpn_bbox_preds, anchors = \
                        model(im_data, im_info, gt_boxes, num_boxes)

                    outputs = (gt_rois, rois, rois_label, cls_pred, bbox_pred,
                               rpn_scores, rpn_bboxs, rpn_cls_scores,
                               rpn_bbox_preds, anchors)
                    loss, rpn_loss_cls, rpn_loss_box, rcnn_loss_cls, rcnn_loss_bbox, rpn_acc, rcnn_acc = \
                        losses['val'](input=outputs, target=[gt_boxes, num_boxes, im_info])

                # statistics
                val_loss.append(loss.item())
                val_rpn_acc.append(rpn_acc.item())
                val_rcnn_acc.append(rcnn_acc.item())
                val_rpn_loss_cls.append(rpn_loss_cls.mean().item())
                val_rpn_loss_box.append(rpn_loss_box.mean().item())
                val_rcnn_loss_cls.append(rcnn_loss_cls.mean().item())
                val_rcnn_loss_bbox.append(rcnn_loss_bbox.mean().item())

                for callback in callbacks['validation_batch_end']:
                    callback(
                        epoch,
                        batch,
                        step,
                        model,
                        dataloaders,
                        losses,
                        optimizer,  # todo should we make this v_batch?
                        data={
                            'inputs': data,
                            'outputs': None,
                            'labels': None
                        },
                        stats={
                            'Validation_Loss': val_loss[-1],
                            'Validation_RPN_Acc': val_rpn_acc[-1],
                            'Validation_RCNN_Acc': val_rcnn_acc[-1],
                            'Validation_RPN_Class_Loss': val_rpn_loss_cls[-1],
                            'Validation_RPN_Box_Loss': val_rpn_loss_box[-1],
                            'Validation_RCNN_Class_Loss':
                            val_rcnn_loss_cls[-1],
                            'Validation_RCNN_Box_Loss': val_rcnn_loss_bbox[-1]
                        })

                v_batch += 1

            avg_v_loss = np.mean(val_loss)
            avg_v_rpn_acc = np.mean(val_rpn_acc)
            avg_v_rcnn_acc = np.mean(val_rcnn_acc)

            print(
                'Avg Validation: Total Loss {:.4f}, RPN Class Loss {:.4f}, RPN Box Loss {:.4f}, RPN Acc: {:.4f}, RCNN Class Loss {:.4f}, RCNN Box Loss {:.4f}, RCNN Acc: {:.4f}'
                .format(avg_v_loss, np.mean(val_rpn_loss_cls),
                        np.mean(val_rpn_loss_box), avg_v_rpn_acc,
                        np.mean(val_rcnn_loss_cls),
                        np.mean(val_rcnn_loss_bbox), avg_v_rcnn_acc))
            logger.info(
                'Avg Validation: Total Loss {:.4f}, RPN Class Loss {:.4f}, RPN Box Loss {:.4f}, RPN Acc: {:.4f}, RCNN Class Loss {:.4f}, RCNN Box Loss {:.4f}, RCNN Acc: {:.4f}'
                .format(avg_v_loss, np.mean(val_rpn_loss_cls),
                        np.mean(val_rpn_loss_box), avg_v_rpn_acc,
                        np.mean(val_rcnn_loss_cls),
                        np.mean(val_rcnn_loss_bbox), avg_v_rcnn_acc))

            # Best validation accuracy yet?
            if avg_v_rcnn_acc > best_acc:
                best_acc = avg_v_rcnn_acc
                best_state = copy.deepcopy(model.state_dict())
                if hasattr(losses['train'], 'reps'):
                    reps = losses['train'].get_reps()
                else:
                    reps = None
                save_checkpoint(config,
                                epoch,
                                model,
                                optimizer,
                                best_acc,
                                reps=reps,
                                is_best=True)

            # End of validation callbacks
            for callback in callbacks['validation_end']:
                callback(epoch,
                         batch,
                         step,
                         model,
                         dataloaders,
                         losses,
                         optimizer,
                         data={
                             'inputs': None,
                             'outputs': None,
                             'labels': None
                         },
                         stats={
                             'Avg_Validation_Loss':
                             avg_v_loss,
                             'Avg_Validation_RPN_Acc':
                             avg_v_rpn_acc,
                             'Avg_Validation_RCNN_Acc':
                             avg_v_rcnn_acc,
                             'Avg_Validation_RPN_Class_Loss':
                             np.mean(val_rpn_loss_cls),
                             'Avg_Validation_RPN_Box_Loss':
                             np.mean(val_rpn_loss_box),
                             'Avg_Validation_RCNN_Class_Loss':
                             np.mean(val_rcnn_loss_cls),
                             'Avg_Validation_RCNN_Box_Loss':
                             np.mean(val_rcnn_loss_bbox)
                         })

        # End of epoch callbacks
        for callback in callbacks['epoch_end']:
            callback(epoch,
                     batch,
                     step,
                     model,
                     dataloaders,
                     losses,
                     optimizer,
                     data={
                         'inputs': None,
                         'outputs': None,
                         'labels': None
                     },
                     stats={
                         'Avg_Training_Loss': avg_loss,
                         'Avg_Training_Acc': avg_rcnn_acc
                     })

        # Checkpoint?
        if config.train.checkpoint_every > 0 and epoch % config.train.checkpoint_every == 0:
            if hasattr(losses['train'], 'reps'):
                reps = losses['train'].get_reps()
            else:
                reps = None
            save_checkpoint(config,
                            epoch,
                            model,
                            optimizer,
                            best_acc,
                            reps=reps,
                            is_best=False)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, time_elapsed % 60))
    logger.info('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    logger.info('Best val Acc: {:4f}'.format(best_acc))

    for callback in callbacks['training_end']:
        callback(epoch,
                 batch,
                 step,
                 model,
                 dataloaders,
                 losses,
                 optimizer,
                 data={},
                 stats={})

    # If there was no validation, save the last model as best
    if config.val.every < 1:
        best_acc = avg_rcnn_acc
        best_state = copy.deepcopy(model.state_dict())
        if hasattr(losses['train'], 'reps'):
            reps = losses['train'].get_reps()
        else:
            reps = None
        save_checkpoint(config,
                        epoch,
                        model,
                        optimizer,
                        best_acc,
                        reps=reps,
                        is_best=True)

    return model, best_state, best_acc, train_loss, avg_rcnn_acc, val_loss, val_acc
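
Every callback hook in both fit variants shares one positional signature. A minimal logging callback that would plug into the batch_end and epoch_end hooks; the SummaryWriter and log directory are assumptions, not part of the source:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/fit')  # hypothetical log directory

def log_stats(epoch, batch, step, model, dataloaders, losses, optimizer,
              data=None, stats=None):
    # Forward every numeric stat the training loop hands us to TensorBoard.
    for name, value in (stats or {}).items():
        if isinstance(value, (int, float)):
            writer.add_scalar(name, value, step)

callbacks = {
    'training_start': [], 'epoch_start': [], 'batch_end': [log_stats],
    'validation_start': [], 'validation_batch_end': [], 'validation_end': [],
    'epoch_end': [log_stats], 'training_end': []
}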