Example #1
    def build_imageAudioNet(self, weights=''):
        net = ImageAudioModel()
        net.apply(weights_init)
        if len(weights) > 0:
            print('Loading the fc weights for imageAudio network')
            checkpointer = Checkpointer(net)
            checkpointer.load_model_only(weights)
        return net
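
weights_init is not defined anywhere on this page. For reference, a minimal sketch of the kind of initializer typically passed to net.apply (an assumption about its shape, not the project's actual function):

import torch.nn as nn

def weights_init(m):
    # net.apply(weights_init) calls this once per submodule; initialize only
    # layers that carry learnable weights and leave everything else untouched.
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)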
Example #2
    def build_imageAudioClassifierNet(self, net_classifier, args, weights=''):
        net = ImageAudioClassifyModel(net_classifier, args)
        net.apply(weights_init)
        if len(weights) > 0:
            print('Loading the weights for imageAudioClassifier network')
            checkpointer = Checkpointer(net)
            checkpointer.load_model_only(weights)
        return net
Example #3
def main(args):
    logger = setup_logger(
        "Listen_to_look, classification",
        args.checkpoint_path,
        True
    )
    logger.debug(args)

    writer = None
    if args.visualization:
        writer = setup_tbx(
            args.checkpoint_path,
            True
        )
    if writer is not None:
        logger.info("Allowed Tensorboard writer")

    # create model
    builder = ModelBuilder()
    net_classifier = builder.build_classifierNet(512, args.num_classes).cuda()
    net_imageAudio = builder.build_imageAudioNet().cuda()
    net_imageAudioClassify = builder.build_imageAudioClassifierNet(net_imageAudio, net_classifier, args, weights=args.weights_audioImageModel).cuda()
    model = builder.build_audioPreviewLSTM(net_imageAudio, net_classifier, args)
    model = model.cuda()
    
    # define loss function (criterion) and optimizer
    criterion = {}
    criterion['CrossEntropyLoss'] = nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True

    checkpointer = Checkpointer(model)

    if args.pretrained_model is not None:
        if not os.path.isfile(args.pretrained_model): 
            list_of_models = glob.glob(os.path.join(args.pretrained_model, "*.pth"))
            args.pretrained_model = max(list_of_models, key=os.path.getctime)
        logger.debug("Loading model only at: {}".format(args.pretrained_model))
        checkpointer.load_model_only(f=args.pretrained_model)
    
    model = torch.nn.parallel.DataParallel(model).cuda()

    # DATA LOADING
    val_ds, val_collate = create_validation_dataset(args, logger=logger)
    val_loader = torch.utils.data.DataLoader(
        val_ds,
        batch_size=args.batch_size,
        num_workers=args.decode_threads,
        collate_fn=val_collate
    )

    avgpool_final_acc, lstm_final_acc, avgpool_mean_ap, lstm_mean_ap, loss_avg = validate(args, 117, val_loader, model, criterion, val_ds=val_ds)
    print(
        "Testing Summary for checkpoint: {}\n"
        "Avgpool Acc: {} \n LSTM Acc: {} \n Avgpool mAP: {} \n LSTM mAP: {}".format(
            args.pretrained_model, avgpool_final_acc*100,
            lstm_final_acc*100, avgpool_mean_ap, lstm_mean_ap
            )
    )
Example #4
    def build_audio(self, weights=''):
        pretrained = True
        original_resnet = torchvision.models.resnet18(pretrained)
        net = AudioNet(original_resnet)
        if len(weights) > 0:
            print('Loading weights for audio stream')
            checkpointer = Checkpointer(net)
            checkpointer.load_model_only(weights)
        return net
Example #5
    def build_classifierNet(
        self,
        input_dims=512,
        num_classes=200,
        weights='',
    ):
        net = ClassifierNet(input_dims, num_classes)
        net.apply(weights_init)
        if len(weights) > 0:
            print('Loading weights for classifier')
            checkpointer = Checkpointer(net)
            checkpointer.load_model_only(weights)
        return net
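
A minimal usage sketch for this builder, with a hypothetical checkpoint path (the weights argument simply routes through Checkpointer.load_model_only as shown above):

builder = ModelBuilder()
net_classifier = builder.build_classifierNet(
    input_dims=512,
    num_classes=200,
    weights='/path/to/classifier_checkpoint.pth',  # hypothetical path
)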
Example #6
    def create_checkpointer(self):
        checkpoints_path = os.path.join(self.logs_path,
                                        Config.checkpoints_folder)
        os.makedirs(checkpoints_path)

        self.checkpointer = Checkpointer(checkpoints_path)
        logging.info(f'Checkpointer initialized at {checkpoints_path}')
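
Note that os.makedirs(checkpoints_path) raises FileExistsError when the folder already exists; if the same logs_path can be reused between runs, a tolerant variant (a suggestion, not the original code) is:

os.makedirs(checkpoints_path, exist_ok=True)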
Example #7
    def __init__(self, args):
        super(Trainer, self).__init__()
        self.args = args
        self.vis = Visualizer(env=args.checkname)
        self.saver = Checkpointer(args.checkname,
                                  args.saver_path,
                                  overwrite=False,
                                  verbose=True,
                                  timestamp=True,
                                  max_queue=args.max_save)

        self.model = LinkCrack()

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            self.model = self.model.cuda()
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        if args.pretrained_model:
            self.model.load_state_dict(
                self.saver.load(self.args.pretrained_model, multi_gpu=True))
            self.vis.log('load checkpoint: %s' % self.args.pretrained_model,
                         'train info')

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        if args.use_adam:
            self.optimizer = torch.optim.Adam(self.model.parameters(),
                                              lr=args.lr,
                                              weight_decay=args.weight_decay)
        else:
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             lr=args.lr,
                                             momentum=args.momentum,
                                             weight_decay=args.weight_decay)

        self.iter_counter = 0

        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # -------------------- Loss --------------------- #

        self.mask_loss = nn.BCEWithLogitsLoss(
            reduction='mean',
            pos_weight=torch.cuda.FloatTensor([args.pos_pixel_weight]))
        self.connected_loss = nn.BCEWithLogitsLoss(
            reduction='mean',
            pos_weight=torch.cuda.FloatTensor([args.pos_link_weight]))

        self.loss_weight = args.loss_weight

        # logger
        self.log_loss = {}
        self.log_acc = {}
        self.save_pos_acc = -1
        self.save_acc = -1
Example #8
class Trainer(object):
    def __init__(self, args):
        super(Trainer, self).__init__()
        self.args = args
        self.vis = Visualizer(env=args.checkname)
        self.saver = Checkpointer(args.checkname,
                                  args.saver_path,
                                  overwrite=False,
                                  verbose=True,
                                  timestamp=True,
                                  max_queue=args.max_save)

        self.model = LinkCrack()

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            self.model = self.model.cuda()
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        if args.pretrained_model:
            self.model.load_state_dict(
                self.saver.load(self.args.pretrained_model, multi_gpu=True))
            self.vis.log('load checkpoint: %s' % self.args.pretrained_model,
                         'train info')

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        if args.use_adam:
            self.optimizer = torch.optim.Adam(self.model.parameters(),
                                              lr=args.lr,
                                              weight_decay=args.weight_decay)
        else:
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             lr=args.lr,
                                             momentum=args.momentum,
                                             weight_decay=args.weight_decay)

        self.iter_counter = 0

        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # -------------------- Loss --------------------- #

        self.mask_loss = nn.BCEWithLogitsLoss(
            reduction='mean',
            pos_weight=torch.cuda.FloatTensor([args.pos_pixel_weight]))
        self.connected_loss = nn.BCEWithLogitsLoss(
            reduction='mean',
            pos_weight=torch.cuda.FloatTensor([args.pos_link_weight]))

        self.loss_weight = args.loss_weight

        # logger
        self.log_loss = {}
        self.log_acc = {}
        self.save_pos_acc = -1
        self.save_acc = -1

    def train_op(self, input, target):
        self.optimizer.zero_grad()

        mask = target[0]
        connected = target[1]

        pred = self.model(input)

        pred_mask = pred[0]
        pred_connected = pred[1]

        mask_loss = self.mask_loss(pred_mask.view(-1, 1), mask.view(
            -1, 1)) / self.args.train_batch_size
        connect_loss = self.connected_loss(pred_connected.view(
            -1, 1), connected.view(-1, 1)) / self.args.train_batch_size

        total_loss = mask_loss + self.loss_weight * connect_loss
        total_loss.backward()
        self.optimizer.step()

        self.iter_counter += 1

        self.log_loss = {
            'mask_loss': mask_loss.item(),
            'connect_loss': connect_loss.item(),
            'total_loss': total_loss.item()
        }

        return torch.cat((pred_mask.clone(), pred_connected.clone()), 1)

    def val_op(self, input, target):
        mask = target[0]
        connected = target[1]

        pred = self.model(input)

        pred_mask = pred[0]
        pred_connected = pred[1]

        mask_loss = self.mask_loss(pred_mask.view(-1, 1), mask.view(
            -1, 1)) / self.args.val_batch_size

        connect_loss = self.connected_loss(pred_connected, connected)
        total_loss = mask_loss + self.loss_weight * connect_loss

        self.log_loss = {
            'mask_loss': mask_loss.item(),
            'connect_loss': connect_loss.item(),
            'total_loss': total_loss.item()
        }

        return torch.cat((pred_mask.clone(), pred_connected.clone()), 1)

    def acc_op(self, pred, target):
        mask = target[0]
        connected = target[1]

        pred = torch.sigmoid(pred)
        pred[pred > self.args.acc_sigmoid_th] = 1
        pred[pred <= self.args.acc_sigmoid_th] = 0

        pred_mask = pred[:, 0, :, :].contiguous()
        pred_connected = pred[:, 1:, :, :].contiguous()

        mask_acc = pred_mask.eq(
            mask.view_as(pred_mask)).sum().item() / mask.numel()
        mask_pos_acc = pred_mask[mask > 0].eq(mask[mask > 0].view_as(
            pred_mask[mask > 0])).sum().item() / mask[mask > 0].numel()
        mask_neg_acc = pred_mask[mask < 1].eq(mask[mask < 1].view_as(
            pred_mask[mask < 1])).sum().item() / mask[mask < 1].numel()
        connected_acc = pred_connected.eq(connected.view_as(
            pred_connected)).sum().item() / connected.numel()
        connected_pos_acc = pred_connected[connected > 0].eq(
            connected[connected > 0].view_as(pred_connected[connected > 0])
        ).sum().item() / connected[connected > 0].numel()
        connected_neg_acc = pred_connected[connected < 1].eq(
            connected[connected < 1].view_as(pred_connected[connected < 1])
        ).sum().item() / connected[connected < 1].numel()

        self.log_acc = {
            'mask_acc': mask_acc,
            'mask_pos_acc': mask_pos_acc,
            'mask_neg_acc': mask_neg_acc,
            'connected_acc': connected_acc,
            'connected_pos_acc': connected_pos_acc,
            'connected_neg_acc': connected_neg_acc
        }

    def training(self):

        try:

            for epoch in range(1, self.args.epochs):
                self.vis.log('Start Epoch %d ...' % epoch, 'train info')
                self.model.train()
                # ---------------------  training ------------------- #
                bar = tqdm(enumerate(self.train_loader),
                           total=len(self.train_loader))
                bar.set_description('Epoch %d --- Training --- :' % epoch)
                for idx, sample in bar:
                    img = sample['image']
                    lab = sample['label']
                    self.scheduler(self.optimizer, idx, epoch, self.save_acc)
                    data, target = img.type(torch.cuda.FloatTensor).to(
                        self.device), [
                            lab[0].type(torch.cuda.FloatTensor).to(
                                self.device),
                            lab[1].type(torch.cuda.FloatTensor).to(self.device)
                        ]

                    pred = self.train_op(data, target)
                    if idx % self.args.vis_train_loss_every == 0:
                        self.vis.log(self.log_loss, 'train_loss')
                        self.vis.plot_many({
                            'train_mask_loss':
                            self.log_loss['mask_loss'],
                            'train_connect_loss':
                            self.log_loss['connect_loss'],
                            'train_total_loss':
                            self.log_loss['total_loss']
                        })

                    if idx % self.args.vis_train_acc_every == 0:
                        self.acc_op(pred, target)
                        self.vis.log(self.log_acc, 'train_acc')
                        self.vis.plot_many({
                            'train_mask_acc':
                            self.log_acc['mask_acc'],
                            'train_connect_acc':
                            self.log_acc['connected_acc'],
                            'train_mask_pos_acc':
                            self.log_acc['mask_pos_acc'],
                            'train_mask_neg_acc':
                            self.log_acc['mask_neg_acc'],
                            'train_connect_pos_acc':
                            self.log_acc['connected_pos_acc'],
                            'train_connect_neg_acc':
                            self.log_acc['connected_neg_acc']
                        })
                    if idx % self.args.vis_train_img_every == 0:
                        self.vis.img_many({
                            'train_img':
                            data.cpu(),
                            'train_pred':
                            pred[:, 0, :, :].unsqueeze(1).contiguous().cpu(),
                            'train_lab':
                            target[0].unsqueeze(1).cpu(),
                            'train_lab_channel_0':
                            target[1][:,
                                      0, :, :].unsqueeze(1).contiguous().cpu(),
                            'train_lab_channel_1':
                            target[1][:,
                                      1, :, :].unsqueeze(1).contiguous().cpu(),
                            'train_lab_channel_2':
                            target[1][:,
                                      2, :, :].unsqueeze(1).contiguous().cpu(),
                            'train_lab_channel_3':
                            target[1][:,
                                      3, :, :].unsqueeze(1).contiguous().cpu(),
                            'train_lab_channel_4':
                            target[1][:,
                                      4, :, :].unsqueeze(1).contiguous().cpu(),
                            'train_lab_channel_5':
                            target[1][:,
                                      5, :, :].unsqueeze(1).contiguous().cpu(),
                            'train_lab_channel_6':
                            target[1][:,
                                      6, :, :].unsqueeze(1).contiguous().cpu(),
                            'train_lab_channel_7':
                            target[1][:,
                                      7, :, :].unsqueeze(1).contiguous().cpu(),
                            'train_pred_channel_0':
                            torch.sigmoid(
                                pred[:,
                                     1, :, :].unsqueeze(1).contiguous().cpu()),
                            'train_pred_channel_1':
                            torch.sigmoid(
                                pred[:,
                                     2, :, :].unsqueeze(1).contiguous().cpu()),
                            'train_pred_channel_2':
                            torch.sigmoid(
                                pred[:,
                                     3, :, :].unsqueeze(1).contiguous().cpu()),
                            'train_pred_channel_3':
                            torch.sigmoid(
                                pred[:,
                                     4, :, :].unsqueeze(1).contiguous().cpu()),
                            'train_pred_channel_4':
                            torch.sigmoid(
                                pred[:,
                                     5, :, :].unsqueeze(1).contiguous().cpu()),
                            'train_pred_channel_5':
                            torch.sigmoid(
                                pred[:,
                                     6, :, :].unsqueeze(1).contiguous().cpu()),
                            'train_pred_channel_6':
                            torch.sigmoid(
                                pred[:,
                                     7, :, :].unsqueeze(1).contiguous().cpu()),
                            'train_pred_channel_7':
                            torch.sigmoid(
                                pred[:,
                                     8, :, :].unsqueeze(1).contiguous().cpu()),
                        })

                    if idx % self.args.val_every == 0:
                        self.vis.log('Start Val %d ....' % idx, 'train info')
                        # -------------------- val ------------------- #
                        self.model.eval()
                        val_loss = {
                            'mask_loss': 0,
                            'connect_loss': 0,
                            'total_loss': 0
                        }
                        val_acc = {
                            'mask_acc': 0,
                            'mask_pos_acc': 0,
                            'mask_neg_acc': 0,
                            'connected_acc': 0,
                            'connected_pos_acc': 0,
                            'connected_neg_acc': 0
                        }

                        bar.set_description('Epoch %d --- Evaluation --- :' %
                                            epoch)

                        with torch.no_grad():
                            for idx, sample in enumerate(self.val_loader,
                                                         start=1):
                                img = sample['image']
                                lab = sample['label']
                                val_data, val_target = img.type(
                                    torch.cuda.FloatTensor).to(self.device), [
                                        lab[0].type(torch.cuda.FloatTensor).to(
                                            self.device),
                                        lab[1].type(torch.cuda.FloatTensor).to(
                                            self.device)
                                    ]
                                val_pred = self.val_op(val_data, val_target)
                                self.acc_op(val_pred, val_target)
                                val_loss['mask_loss'] += self.log_loss[
                                    'mask_loss']
                                val_loss['connect_loss'] += self.log_loss[
                                    'connect_loss']
                                val_loss['total_loss'] += self.log_loss[
                                    'total_loss']
                                val_acc['mask_acc'] += self.log_acc['mask_acc']
                                val_acc['connected_acc'] += self.log_acc[
                                    'connected_acc']
                                val_acc['mask_pos_acc'] += self.log_acc[
                                    'mask_pos_acc']
                                val_acc['connected_pos_acc'] += self.log_acc[
                                    'connected_pos_acc']
                                val_acc['mask_neg_acc'] += self.log_acc[
                                    'mask_neg_acc']
                                val_acc['connected_neg_acc'] += self.log_acc[
                                    'connected_neg_acc']
                            else:
                                self.vis.img_many({
                                    'val_img':
                                    val_data.cpu(),
                                    'val_pred':
                                    val_pred[:,
                                             0, :, :].contiguous().unsqueeze(
                                                 1).cpu(),
                                    'val_lab':
                                    val_target[0].unsqueeze(1).cpu()
                                })
                                self.vis.plot_many({
                                    'val_mask_loss':
                                    val_loss['mask_loss'] / idx,
                                    'val_connect_loss':
                                    val_loss['connect_loss'] / idx,
                                    'val_total_loss':
                                    val_loss['total_loss'] / idx,
                                })
                                self.vis.plot_many({
                                    'val_mask_acc':
                                    val_acc['mask_acc'] / idx,
                                    'val_connect_acc':
                                    val_acc['connected_acc'] / idx,
                                    'val_mask_pos_acc':
                                    val_acc['mask_pos_acc'] / idx,
                                    'val_mask_neg_acc':
                                    val_acc['mask_neg_acc'] / idx,
                                    'val_connected_pos_acc':
                                    val_acc['connected_pos_acc'] / idx,
                                    'val_connected_neg_acc':
                                    val_acc['connected_neg_acc'] / idx
                                })
                        bar.set_description('Epoch %d --- Training --- :' %
                                            epoch)

                        # ----------------- save model ---------------- #
                        if self.save_pos_acc < (val_acc['mask_pos_acc'] / idx):
                            self.save_pos_acc = (val_acc['mask_pos_acc'] / idx)
                            self.save_acc = (val_acc['mask_acc'] / idx)
                            self.saver.save(
                                self.model,
                                tag='connected_weight(%f)_pos_acc(%0.5f)' %
                                (self.loss_weight,
                                 val_acc['mask_pos_acc'] / idx))
                            self.vis.log(
                                'Save Model -connected_weight(%f)_pos_acc(%0.5f)'
                                % (self.loss_weight,
                                   val_acc['mask_pos_acc'] / idx),
                                'train info')

                        if epoch % 5 == 0 and epoch != 0:
                            self.save_pos_acc = (val_acc['mask_pos_acc'] / idx)
                            self.save_acc = (val_acc['mask_acc'] / idx)
                            self.saver.save(
                                self.model,
                                tag=
                                'connected_weight(%f)_epoch(%d)_pos_acc(%0.5f)'
                                % (self.loss_weight, epoch,
                                   val_acc['mask_pos_acc'] / idx))
                            self.vis.log(
                                'Save Model -connected_weight(%f)_pos_acc(%0.5f)'
                                % (self.loss_weight,
                                   val_acc['mask_pos_acc'] / idx),
                                'train info')

                        # if idx % 1000 == 0:
                        #     self.save_pos_acc = (val_acc['mask_pos_acc'] / idx)
                        #     self.save_acc = (val_acc['mask_acc'] / idx)
                        #     self.saver.save(self.model, tag='connected_weight(%f)_epoch(%d)_pos_acc(%0.5f)' % (
                        #         self.loss_weight, epoch, val_acc['mask_pos_acc'] / idx))
                        #     self.vis.log('Save Model -connected_weight(%f)_pos_acc(%0.5f)' % (
                        #         self.loss_weight, val_acc['mask_pos_acc'] / idx), 'train info')

                        self.model.train()

        except KeyboardInterrupt:

            self.saver.save(self.model, tag='Auto_Save_Model')
            print('\n Catch KeyboardInterrupt, Auto Save final model : %s' %
                  self.saver.show_save_pth_name)
            self.vis.log(
                'Catch KeyboardInterrupt, Auto Save final model : %s' %
                self.saver.show_save_pth_name, 'train info')
            self.vis.log('Training End!!')
            try:
                sys.exit(0)
            except SystemExit:
                os._exit(0)
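
Given the constructor and training() method above, a minimal driver sketch (the args namespace is assumed to be parsed elsewhere, e.g. with argparse):

# Hypothetical entry point: build the trainer and run the training loop defined above.
trainer = Trainer(args)
trainer.training()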
Example #9
    args = parser.parse_args()

    config.merge_from_list(args.opts)
    config.freeze()

    save_dir = os.path.join(config.save_dir)
    mkdir(save_dir)
    logger = setup_logger("inference", save_dir, 0)
    logger.info("Running with config:\n{}".format(config))

    device = torch.device(config.device)
    num_types = len(config.boundaries) + 2

    generator = Generator(BertConfig(type_vocab_size=num_types))
    generator = generator.to(device)
    g_checkpointer = Checkpointer(model=generator, logger=logger)
    g_checkpointer.load(config.model_path, True)

    dataset = COCOCaptionDataset(root=config.data_dir,
                                 split='test',
                                 boundaries=config.boundaries)
    data_loader = make_data_loader(dataset=dataset,
                                   collate_fn=collate_fn_infer,
                                   batch_size=config.samples_per_gpu,
                                   num_workers=config.num_workers,
                                   split='test')

    pred_dict = inference(generator, data_loader, device)
    logger.info(f"Saving results to {save_dir}/caption_results.json")
    with open(os.path.join(save_dir, 'caption_results.json'), 'w') as f:
        json.dump(pred_dict, f)
Example #10
        params=generator.parameters(),
        lr=config.solver.lr,
        weight_decay=config.solver.weight_decay,
        betas=config.solver.betas
    )

    scheduler = WarmupCosineSchedule(
        optimizer=optimizer,
        warmup_steps=config.scheduler.warmup_steps,
        t_total=config.scheduler.max_steps
    )

    checkpointer = Checkpointer(
        model=generator,
        optimizer=optimizer,
        scheduler=scheduler,
        save_dir=save_dir,
        save_to_disk=get_rank() == 0,
        logger=logger
    )

    if config.model_path == '':
        generator.load_weights(config.pretrained_bert)
    else:
        extra_checkpoint_data = checkpointer.load(config.model_path)
        arguments.update(extra_checkpoint_data)

    dataset = COCOCaptionDataset(
        root=config.data_dir,
        split='trainrestval',
        boundaries=config.boundaries,
    )
Example #11
def main(args):

    os.makedirs(args.checkpoint_path, exist_ok=True)
    # Setup logging system
    logger = setup_logger(
        "Listen_to_look, audio_preview classification single modality",
        args.checkpoint_path, True)
    logger.debug(args)
    # Epoch logging
    epoch_log = setup_logger("Listen_to_look: results",
                             args.checkpoint_path,
                             True,
                             logname="epoch.log")
    epoch_log.info("epoch,loss,acc,lr")

    writer = None
    if args.visualization:
        writer = setup_tbx(args.checkpoint_path, True)
    if writer is not None:
        logger.info("Allowed Tensorboard writer")

    # Define the model
    builder = ModelBuilder()
    net_classifier = builder.build_classifierNet(args.embedding_size,
                                                 args.num_classes).cuda()
    net_imageAudioClassify = builder.build_imageAudioClassifierNet(
        net_classifier, args).cuda()
    model = builder.build_audioPreviewLSTM(net_classifier, args)
    model = model.cuda()

    # DATA LOADING
    train_ds, train_collate = create_training_dataset(args, logger=logger)
    val_ds, val_collate = create_validation_dataset(args, logger=logger)

    train_loader = torch.utils.data.DataLoader(train_ds,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.decode_threads,
                                               collate_fn=train_collate)
    val_loader = torch.utils.data.DataLoader(val_ds,
                                             batch_size=args.batch_size,
                                             num_workers=4,
                                             collate_fn=val_collate)

    args.iters_per_epoch = len(train_loader)
    args.warmup_iters = args.warmup_epochs * args.iters_per_epoch
    args.milestones = [args.iters_per_epoch * m for m in args.milestones]

    # define loss function (criterion) and optimizer
    criterion = {}
    criterion['CrossEntropyLoss'] = nn.CrossEntropyLoss().cuda()

    if args.freeze_imageAudioNet:
        param_groups = [{
            'params': model.queryfeature_mlp.parameters(),
            'lr': args.lr
        }, {
            'params': model.prediction_fc.parameters(),
            'lr': args.lr
        }, {
            'params': model.key_conv1x1.parameters(),
            'lr': args.lr
        }, {
            'params': model.rnn.parameters(),
            'lr': args.lr
        }, {
            'params': net_classifier.parameters(),
            'lr': args.lr
        }]
        optimizer = torch.optim.SGD(param_groups,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=1)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=1)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     args.milestones)
    # make optimizer scheduler
    if args.scheduler:
        scheduler = default_lr_scheduler(optimizer, args.milestones,
                                         args.warmup_iters)

    cudnn.benchmark = True

    # setting up the checkpointing system
    write_here = True
    checkpointer = Checkpointer(model,
                                optimizer,
                                save_dir=args.checkpoint_path,
                                save_to_disk=write_here,
                                scheduler=scheduler,
                                logger=logger)

    if args.pretrained_model is not None:
        logger.debug("Loading model only at: {}".format(args.pretrained_model))
        checkpointer.load_model_only(f=args.pretrained_model)

    if checkpointer.has_checkpoint():
        # call load checkpoint
        logger.debug("Loading last checkpoint")
        checkpointer.load()

    model = torch.nn.parallel.DataParallel(model).cuda()
    logger.debug(model)

    # Log all info
    if writer:
        writer.add_text("namespace", repr(args))
        writer.add_text("model", str(model))

    #
    # TRAINING
    #
    logger.debug("Entering the training loop")
    for epoch in range(args.start_epoch, args.epochs):

        # train for one epoch
        train_accuracy, train_loss = train_epoch(args,
                                                 epoch,
                                                 train_loader,
                                                 model,
                                                 criterion,
                                                 optimizer,
                                                 scheduler,
                                                 logger,
                                                 epoch_logger=epoch_log,
                                                 checkpointer=checkpointer,
                                                 writer=writer)

        test_map, test_accuracy, test_loss, _ = validate(
            args,
            epoch,
            val_loader,
            model,
            criterion,
            epoch_logger=epoch_log,
            writer=writer)
        if writer is not None:
            writer.add_scalars('training_curves/accuracies', {
                'train': train_accuracy,
                'val': test_accuracy
            }, epoch)
            writer.add_scalars('training_curves/loss', {
                'train': train_loss,
                'val': test_loss
            }, epoch)