def main():

    import argparse
    parser = argparse.ArgumentParser(description="Pytorch Image CNN training from Configure Files")
    parser.add_argument('--config_file', required=True, help="This script only accepts parameters from JSON files")
    input_args = parser.parse_args()

    config_file = input_args.config_file

    args = parse_config(config_file)
    if args.name is None:
        args.name = get_stem(config_file)

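    # Keep FP32 as the default tensor type and track the best top-1 validation accuracy seen so far.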
    torch.set_default_tensor_type('torch.FloatTensor')
    best_prec1 = 0

    args.script_name = get_stem(__file__)
    current_time_str = get_date_str()
    # if args.resume is None:
    if args.save_directory is None:
        save_directory = get_dir(os.path.join(project_root, 'ckpts2', '{:s}'.format(args.name), '{:s}-{:s}'.format(args.ID, current_time_str)))
    else:
        save_directory = get_dir(os.path.join(project_root, 'ckpts2', args.save_directory))
    # else:
    #     save_directory = os.path.dirname(args.resume)
    print("Save to {}".format(save_directory))
    log_file = os.path.join(save_directory, 'log-{0}.txt'.format(current_time_str))
    logger = log_utils.get_logger(log_file)
    log_utils.print_config(vars(args), logger)


    print_func = logger.info
    print_func('ConfigFile: {}'.format(config_file))
    args.log_file = log_file

    if args.device:
        os.environ["CUDA_VISIBLE_DEVICES"]=args.device


    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

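    # Multi-process distributed training is enabled whenever more than one process (world_size > 1) is configured.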
    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)

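    # Build the visual backbone by looking up the architecture name in the models registry, optionally with pretrained weights.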
    if args.pretrained:
        print_func("=> using pre-trained model '{}'".format(args.arch))
        visual_model = models.__dict__[args.arch](pretrained=True, num_classes=args.num_classes)
    else:
        print_func("=> creating model '{}'".format(args.arch))
        visual_model = models.__dict__[args.arch](pretrained=False, num_classes=args.num_classes)

    if args.freeze:
        visual_model = CNN_utils.freeze_all_except_fc(visual_model)



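    # Restore the pre-trained TextCNN from its checkpoint; it is used in eval mode only and is not optimized below.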
    if os.path.isfile(args.text_ckpt):
        print_func("=> loading checkpoint '{}'".format(args.text_ckpt))
        text_data = torch.load(args.text_ckpt, map_location=lambda storage, loc: storage)
        text_model = TextCNN(text_data['args_model'])
        # load_state_dict(text_model, text_data['state_dict'])
        text_model.load_state_dict(text_data['state_dict'], strict=True)
        text_model.eval()
        print_func("=> loaded checkpoint '{}' for text classification"
              .format(args.text_ckpt))
        args.vocab_size = text_data['args_model'].vocab_size
    else:
        print_func("=> no checkpoint found at '{}'".format(args.text_ckpt))
        return


    args.tag2clsidx = text_data['args_data'].tag2idx
    args.vocab_size = len(args.tag2clsidx)

    args.text_embed = loadpickle(args.text_embed)
    args.idx2tag = loadpickle(args.idx2tag)['idx2tag']



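    # Move the models to the GPU(s): a single device, DistributedDataParallel, or DataParallel.
    # Note that in the DataParallel fallback, only the non-alexnet/vgg branch also wraps the text model.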
    if args.gpu is not None:
        visual_model = visual_model.cuda(args.gpu)
        text_model = text_model.cuda(args.gpu)
    elif args.distributed:
        visual_model.cuda()
        visual_model = torch.nn.parallel.DistributedDataParallel(visual_model)
    else:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            visual_model.features = torch.nn.DataParallel(visual_model.features)
            visual_model.cuda()
        else:
            visual_model = torch.nn.DataParallel(visual_model).cuda()
            text_model = torch.nn.DataParallel(text_model).cuda()


    criterion = nn.CrossEntropyLoss(ignore_index=-1).cuda(args.gpu)

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, visual_model.parameters()), lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

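    # Either a fixed multi-step schedule (decay LR by 10x at the listed epochs) or ReduceLROnPlateau on the validation loss.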
    if args.lr_schedule:
        print_func("Using scheduled learning rate")
        scheduler = lr_scheduler.MultiStepLR(
            optimizer, [int(i) for i in args.lr_schedule.split(',')], gamma=0.1)
    else:
        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=args.lr_patience)

    # optimizer = torch.optim.SGD(model.parameters(), args.lr,
    #                             momentum=args.momentum,
    #                             weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_func("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            import collections
            if isinstance(checkpoint, collections.OrderedDict):
                load_state_dict(visual_model, checkpoint)


            else:
                load_state_dict(visual_model, checkpoint['state_dict'])
                print_func("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))

        else:
            print_func("=> no checkpoint found at '{}'".format(args.resume))



    cudnn.benchmark = True

    model_total_params = sum(p.numel() for p in visual_model.parameters())
    model_grad_params = sum(p.numel() for p in visual_model.parameters() if p.requires_grad)
    print_func("Total Parameters: {0}\t Gradient Parameters: {1}".format(model_total_params, model_grad_params))

    # Data loading code
    val_dataset = get_instance(custom_datasets, '{0}'.format(args.valloader), args)
    if val_dataset is None:
        val_loader = None
    else:
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.workers, pin_memory=True, collate_fn=none_collate)

    if args.evaluate:
        print_func('Validation Only')
        validate(val_loader, visual_model, criterion, args, print_func)
        return
    else:

        train_dataset = get_instance(custom_datasets, '{0}'.format(args.trainloader), args)

        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        else:
            train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=none_collate)




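    # Main loop: train one epoch, validate, and checkpoint, keeping track of the best top-1 accuracy.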
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        if args.lr_schedule:
            # CNN_utils.adjust_learning_rate(optimizer, epoch, args.lr)
            scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        print_func("Epoch: [{}], learning rate: {}".format(epoch, current_lr))

        # train for one epoch
        train(train_loader, visual_model, text_model, criterion, optimizer, epoch, args, print_func)

        # evaluate on validation set
        if val_loader:
            prec1, val_loss = validate(val_loader, visual_model, criterion, args, print_func)
        else:
            prec1 = 0
            val_loss = 0
        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        CNN_utils.save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': visual_model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best, file_directory=save_directory, epoch=epoch)

        if not args.lr_schedule:
            scheduler.step(val_loss)


# Example 2
def main():
    global args
    args = parser.parse_args()
    use_cuda = cuda_model.ifUseCuda(args.gpu_id, args.multiGpu)
    script_name_stem = dir_utils.get_stem(__file__)
    if args.resume is None:

        save_directory = dir_utils.get_dir(os.path.join(
            project_root, 'ckpts', '{:s}'.format(args.dataset),
            '{:s}-{:s}-assgin{:.2f}-alpha{:.4f}-dim{:d}-dropout{:.4f}-seqlen{:d}-{:s}-{:s}'.format(
                script_name_stem, args.sufix, args.hassign_thres, args.alpha, args.hidden_dim,
                args.dropout, args.seq_len, 'L2', match_type[args.hmatch])))
    else:
        save_directory = args.resume

    log_file = os.path.join(save_directory, 'log-{:s}.txt'.format(dir_utils.get_date_str()))
    logger = log_utils.get_logger(log_file)
    log_utils.print_config(vars(args), logger)

    model = PointerNetwork(input_dim=args.input_dim, embedding_dim=args.embedding_dim,
                           hidden_dim=args.hidden_dim, max_decoding_len=args.net_outputs, dropout=args.dropout, n_enc_layers=2, output_classes=2)
    logger.info("Number of Params\t{:d}".format(sum([p.data.nelement() for p in model.parameters()])))
    logger.info('Saving logs to {:s}'.format(log_file))

    if args.resume is not None:

        ckpt_idx = args.fileid
        ckpt_filename = os.path.join(args.resume, 'checkpoint_{:04d}.pth.tar'.format(ckpt_idx))
        assert os.path.isfile(ckpt_filename), 'Error: no checkpoint directory found!'

        checkpoint = torch.load(ckpt_filename, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        train_iou = checkpoint['IoU']
        args.start_epoch = checkpoint['epoch']

        logger.info("=> loading checkpoint '{}', current iou: {:.04f}".format(ckpt_filename, train_iou))


    model = cuda_model.convertModel2Cuda(model, gpu_id=args.gpu_id, multiGpu=args.multiGpu)
    train_dataset = cDataset(dataset_split='train', seq_length=args.seq_len, sample_rate=[4], rdOffset=True, rdDrop=True)
    val_dataset = cDataset(dataset_split='val', seq_length=args.seq_len, sample_rate=[4], rdDrop=False, rdOffset=False)


    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=4)
    val_dataloader = DataLoader(val_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=4)

    model_optim = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=float(args.lr))
    optim_scheduler = optim.lr_scheduler.ReduceLROnPlateau(model_optim, 'min', patience=10)

    cls_weights = torch.FloatTensor([0.05, 1.0]).cuda()
    # cls_weights = None

    widgets = ['Train: ', ' -- [ ', progressbar.Counter(), '|', str(len(train_dataloader)), ' ] ',
               progressbar.Bar(), ' cls loss:  ', progressbar.FormatLabel(''),
               ' loc loss: ', progressbar.FormatLabel(''),
               ' IoU : ', progressbar.FormatLabel(''),
               ' (', progressbar.ETA(), ' ) ']

    # bar = progressbar.ProgressBar(max_value=step_per_epoch, widgets=widgets)
    # bar.start()

    for epoch in range(args.start_epoch, args.nof_epoch+args.start_epoch):

            total_losses = AverageMeter()
            loc_losses = AverageMeter()
            cls_losses = AverageMeter()
            matched_IOU = AverageMeter()
            true_IOU = AverageMeter()
            model.train()
            pbar = progressbar.ProgressBar(max_value=len(train_dataloader), widgets=widgets)
            pbar.start()
            for i_batch, sample_batch in enumerate(train_dataloader):
                # pbar.update(i_batch)

                feature_batch = Variable(sample_batch[0])
                start_indices = Variable(sample_batch[1])
                end_indices = Variable(sample_batch[2])
                gt_valids = Variable(sample_batch[3])
                # gt_overlaps = Variable(sample_batch[4])

                # seq_labels = Variable(sample_batch[3])

                if use_cuda:
                    feature_batch = feature_batch.cuda()
                    start_indices = start_indices.cuda()
                    end_indices = end_indices.cuda()

                gt_positions = torch.stack([start_indices, end_indices], dim=-1)


                head_pointer_probs, head_positions, tail_pointer_probs, tail_positions, cls_scores, _ = model(feature_batch)

                pred_positions = torch.stack([head_positions, tail_positions], dim=-1)
                # pred_scores = F.sigmoid(cls_scores)

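                # Assign predicted (head, tail) segments to ground-truth segments (h_match or f_match) using the assignment IoU threshold.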
                if args.hmatch:
                    assigned_scores, assigned_locations, total_valid, total_iou = h_match.Assign_Batch_v2(gt_positions, pred_positions, gt_valids, thres=args.hassign_thres)
                else:
                    #FIXME: do it later!
                    assigned_scores, assigned_locations, total_valid, total_iou = f_match.Assign_Batch_v2(gt_positions, pred_positions, gt_valids, thres=args.hassign_thres)
                    # _, _, total_valid, total_iou = h_match.Assign_Batch_v2(gt_positions, pred_positions, gt_valids, thres=args.hassign_thres)
                true_valid, true_iou = h_match.totalMatch_Batch(gt_positions, pred_positions, gt_valids)
                assert true_valid == total_valid, 'WRONG'

                if total_valid>0:
                    matched_IOU.update(total_iou / total_valid, total_valid)
                    true_IOU.update(true_iou/total_valid, total_valid)

                assigned_scores = Variable(torch.LongTensor(assigned_scores),requires_grad=False)
                # assigned_overlaps = Variable(torch.FloatTensor(assigned_overlaps), requires_grad=False)
                assigned_locations = Variable(torch.LongTensor(assigned_locations), requires_grad=False)
                if use_cuda:
                    assigned_scores = assigned_scores.cuda()
                    assigned_locations = assigned_locations.cuda()
                    # assigned_overlaps = assigned_overlaps.cuda()

                # pred_scores = pred_scores.contiguous().view(-1)
                # assigned_scores = assigned_scores.contiguous().view(-1)
                # assigned_overlaps = assigned_overlaps.contiguous().view(-1)
                # cls_loss = ClsLocLoss2_OneClsRegression(pred_scores, assigned_scores, assigned_overlaps)
                cls_scores = cls_scores.contiguous().view(-1, cls_scores.size()[-1])
                assigned_scores = assigned_scores.contiguous().view(-1)

                cls_loss = F.cross_entropy(cls_scores, assigned_scores, weight=cls_weights)

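                # The localization loss is computed only over positions that received a positive assignment.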
                if total_valid>0:
                    assigned_head_positions = assigned_locations[:,:,0]
                    assigned_head_positions = assigned_head_positions.contiguous().view(-1)
                    #
                    assigned_tail_positions = assigned_locations[:,:,1]
                    assigned_tail_positions = assigned_tail_positions.contiguous().view(-1)


                    head_pointer_probs = head_pointer_probs.contiguous().view(-1, head_pointer_probs.size()[-1])
                    tail_pointer_probs = tail_pointer_probs.contiguous().view(-1, tail_pointer_probs.size()[-1])

                    assigned_head_positions = torch.masked_select(assigned_head_positions, assigned_scores.byte())
                    assigned_tail_positions = torch.masked_select(assigned_tail_positions, assigned_scores.byte())

                    head_pointer_probs = torch.index_select(head_pointer_probs, dim=0, index=assigned_scores.nonzero().squeeze(1))
                    tail_pointer_probs = torch.index_select(tail_pointer_probs, dim=0, index=assigned_scores.nonzero().squeeze(1))

                    # if args.EMD:
                    assigned_head_positions = to_one_hot(assigned_head_positions, args.seq_len)
                    assigned_tail_positions = to_one_hot(assigned_tail_positions, args.seq_len)

                    prediction_head_loss = Simple_L2(head_pointer_probs, assigned_head_positions, needSoftMax=True)
                    prediction_tail_loss = Simple_L2(tail_pointer_probs, assigned_tail_positions, needSoftMax=True)
                    # else:
                    #     prediction_head_loss = F.cross_entropy(head_pointer_probs, assigned_head_positions)
                    #     prediction_tail_loss = F.cross_entropy(tail_pointer_probs, assigned_tail_positions)
                    loc_losses.update(prediction_head_loss.data.item() + prediction_tail_loss.data.item(),
                                      total_valid)#FIXME
                    total_loss =  args.alpha*(prediction_head_loss + prediction_tail_loss) + cls_loss
                else:
                    total_loss = cls_loss

                model_optim.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                model_optim.step()
                cls_losses.update(cls_loss.data.item(), feature_batch.size(0))
                total_losses.update(total_loss.item(), feature_batch.size(0))

                widgets[-8] = progressbar.FormatLabel('{:04.4f}'.format(cls_losses.avg))
                widgets[-6] = progressbar.FormatLabel('{:04.4f}'.format(loc_losses.avg))
                widgets[-4] = progressbar.FormatLabel('{:01.4f}'.format(matched_IOU.avg))
                pbar.update(i_batch)


            logger.info(
                "Train -- Epoch :{:06d}, LR: {:.6f},\tloss={:.4f}, \t c-loss:{:.4f}, \tloc-loss:{:.4f}\tAvg-matched_IOU:{:.4f}\t Avg-true-IOU:{:.4f}".format(
                    epoch,
                    model_optim.param_groups[0]['lr'], total_losses.avg, cls_losses.avg, loc_losses.avg, matched_IOU.avg, true_IOU.avg))
            train_iou = matched_IOU.avg
            optim_scheduler.step(total_losses.avg)

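            # Validation: report the average IoU of matched segments on the val split.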
            model.eval()

            matched_IOU = AverageMeter()
            pbar = progressbar.ProgressBar(max_value=len(val_dataloader))
            for i_batch, sample_batch in enumerate(val_dataloader):
                pbar.update(i_batch)

                feature_batch = Variable(sample_batch[0])
                start_indices = Variable(sample_batch[1])
                end_indices = Variable(sample_batch[2])
                gt_valids = Variable(sample_batch[3])
                # valid_indices = Variable(sample_batch[3])

                if use_cuda:
                    feature_batch = feature_batch.cuda()
                    start_indices = start_indices.cuda()
                    end_indices = end_indices.cuda()

                gt_positions = torch.stack([start_indices, end_indices], dim=-1)

                head_pointer_probs, head_positions, tail_pointer_probs, tail_positions, cls_scores, _ = model(
                    feature_batch)

                pred_positions = torch.stack([head_positions, tail_positions], dim=-1)

                # assigned_scores, assigned_locations, total_valid, total_iou = h_match.Assign_Batch_eval(gt_positions, pred_positions, gt_valids, thres=args.hassign_thres) #FIXME
                matched_valid, matched_iou = h_match.totalMatch_Batch(gt_positions, pred_positions, gt_valids)
                if matched_valid>0:
                    matched_IOU.update(matched_iou / matched_valid, matched_valid)

            logger.info(
                "Val -- Epoch :{:06d}, LR: {:.6f},\tloc-Avg-matched_IOU:{:.4f}".format(
                    epoch,model_optim.param_groups[0]['lr'], matched_IOU.avg, ))


            if epoch % 1 == 0:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'loss': total_losses.avg,
                    'cls_loss': cls_losses.avg,
                    'loc_loss': loc_losses.avg,
                    'train-IOU': train_iou,
                    'IoU': matched_IOU.avg}, (epoch + 1), file_direcotry=save_directory)


# Example 3
def main():
    global args
    args = parser.parse_args()
    use_cuda = cuda_model.ifUseCuda(args.gpu_id, args.multiGpu)
    script_name_stem = dir_utils.get_stem(__file__)
    save_directory = dir_utils.get_dir(
        os.path.join(
            project_root, 'ckpts',
            '{:s}-{:s}-{:s}-split-{:d}-claweight-{:s}-{:.1f}-assgin{:.2f}-alpha{:.4f}-dim{:d}-dropout{:.4f}-seqlen{:d}-samplerate-{:d}-{:s}-{:s}'
            .format(script_name_stem, args.dataset, args.eval_metrics,
                    args.split, str(args.set_cls_weight), args.cls_pos_weight,
                    args.hassign_thres, args.alpha, args.hidden_dim,
                    args.dropout, args.seq_len, args.sample_rate,
                    loss_type[args.EMD], match_type[args.hmatch])))
    log_file = os.path.join(save_directory,
                            'log-{:s}.txt'.format(dir_utils.get_date_str()))
    logger = log_utils.get_logger(log_file)
    log_utils.print_config(vars(args), logger)

    model = PointerNetwork(input_dim=args.input_dim,
                           embedding_dim=args.embedding_dim,
                           hidden_dim=args.hidden_dim,
                           max_decoding_len=args.net_outputs,
                           dropout=args.dropout,
                           n_enc_layers=2,
                           output_classes=2)
    hassign_thres = args.hassign_thres
    logger.info("Number of Params\t{:d}".format(
        sum([p.data.nelement() for p in model.parameters()])))
    logger.info('Saving logs to {:s}'.format(log_file))

    if args.resume is not None:

        ckpt_idx = 48

        ckpt_filename = args.resume.format(ckpt_idx)
        assert os.path.isfile(
            ckpt_filename), 'Error: no checkpoint directory found!'

        checkpoint = torch.load(ckpt_filename,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        train_iou = checkpoint['IoU']
        args.start_epoch = checkpoint['epoch']

        logger.info("=> loading checkpoint '{}', current iou: {:.04f}".format(
            ckpt_filename, train_iou))

    model = cuda_model.convertModel2Cuda(model,
                                         gpu_id=args.gpu_id,
                                         multiGpu=args.multiGpu)
    # get train/val split
    if args.dataset == 'SumMe':
        train_val_test_perms = np.arange(25)
    elif args.dataset == 'TVSum':
        train_val_test_perms = np.arange(50)
    else:
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))
    # fixed permutation
    random.Random(0).shuffle(train_val_test_perms)
    train_val_test_perms = train_val_test_perms.reshape([5, -1])
    train_val_perms = np.delete(train_val_test_perms, args.split,
                                0).reshape([-1])
    train_perms = train_val_perms[:17]
    val_perms = train_val_perms[17:]
    test_perms = train_val_test_perms[args.split]
    logger.info(" training split: " + str(train_perms))
    logger.info(" val split: " + str(val_perms))
    logger.info(" test split: " + str(test_perms))

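    # Dataset location depends on whether the job runs on the local home directory or an NFS mount.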
    if args.location == 'home':
        data_path = os.path.join(os.path.expanduser('~'), 'datasets')
    else:
        data_path = os.path.join('/nfs/%s/boyu/SDN' % (args.location),
                                 'datasets')
    train_dataset = vsSumLoader3_c3dd.cDataset(dataset_name=args.dataset,
                                               split='train',
                                               seq_length=args.seq_len,
                                               overlap=0.9,
                                               sample_rate=[args.sample_rate],
                                               train_val_perms=train_perms,
                                               data_path=data_path)
    val_evaluator = Evaluator.Evaluator(dataset_name=args.dataset,
                                        split='val',
                                        seq_length=args.seq_len,
                                        overlap=0.9,
                                        sample_rate=[args.sample_rate],
                                        sum_budget=0.15,
                                        train_val_perms=val_perms,
                                        eval_metrics=args.eval_metrics,
                                        data_path=data_path)
    test_evaluator = Evaluator.Evaluator(dataset_name=args.dataset,
                                         split='test',
                                         seq_length=args.seq_len,
                                         overlap=0.9,
                                         sample_rate=[args.sample_rate],
                                         sum_budget=0.15,
                                         train_val_perms=test_perms,
                                         eval_metrics=args.eval_metrics,
                                         data_path=data_path)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=4)
    # val_dataloader = DataLoader(val_dataset,
    #                               batch_size=args.batch_size,
    #                               shuffle=False,
    #                               num_workers=4)

    model_optim = optim.Adam(filter(lambda p: p.requires_grad,
                                    model.parameters()),
                             lr=float(args.lr))
    optim_scheduler = optim.lr_scheduler.ReduceLROnPlateau(model_optim,
                                                           'min',
                                                           patience=10)

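    # Classification weights: either derived from the positive/total sample ratio of the training set or uniform.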
    alpha = args.alpha
    # cls_weights = torch.FloatTensor([0.2, 1.0]).cuda()
    if args.set_cls_weight:
        cls_weights = torch.FloatTensor([
            1. * train_dataset.n_positive_train_samples /
            train_dataset.n_total_train_samples, args.cls_pos_weight
        ]).cuda()
    else:
        cls_weights = torch.FloatTensor([0.5, 0.5]).cuda()
    logger.info(" total: {:d}, total pos: {:d}".format(
        train_dataset.n_total_train_samples,
        train_dataset.n_positive_train_samples))
    logger.info(" classify weight: " + str(cls_weights[0]) +
                str(cls_weights[1]))
    for epoch in range(args.start_epoch, args.nof_epoch + args.start_epoch):
        total_losses = AverageMeter()
        loc_losses = AverageMeter()
        cls_losses = AverageMeter()
        Accuracy = AverageMeter()
        IOU = AverageMeter()
        ordered_IOU = AverageMeter()
        model.train()
        pbar = progressbar.ProgressBar(max_value=len(train_dataloader))
        for i_batch, sample_batch in enumerate(train_dataloader):
            pbar.update(i_batch)

            feature_batch = Variable(sample_batch[0])
            start_indices = Variable(sample_batch[1])
            end_indices = Variable(sample_batch[2])
            gt_valids = Variable(sample_batch[3])
            # seq_labels = Variable(sample_batch[3])

            if use_cuda:
                feature_batch = feature_batch.cuda()
                start_indices = start_indices.cuda()
                end_indices = end_indices.cuda()

            gt_positions = torch.stack([start_indices, end_indices], dim=-1)

            head_pointer_probs, head_positions, tail_pointer_probs, tail_positions, cls_scores, _ = model(
                feature_batch)

            pred_positions = torch.stack([head_positions, tail_positions],
                                         dim=-1)
            if args.hmatch:
                assigned_scores, assigned_locations, total_valid, total_iou = h_match.Assign_Batch_v2(
                    gt_positions,
                    pred_positions,
                    gt_valids,
                    thres=hassign_thres)

            else:
                assigned_scores, assigned_locations = f_match.Assign_Batch(
                    gt_positions,
                    pred_positions,
                    gt_valids,
                    thres=hassign_thres)
                _, _, total_valid, total_iou = h_match.Assign_Batch_v2(
                    gt_positions,
                    pred_positions,
                    gt_valids,
                    thres=hassign_thres)

            if total_valid > 0:
                IOU.update(total_iou / total_valid, total_valid)

            assigned_scores = Variable(torch.LongTensor(assigned_scores),
                                       requires_grad=False)
            assigned_locations = Variable(torch.LongTensor(assigned_locations),
                                          requires_grad=False)
            if use_cuda:
                assigned_scores = assigned_scores.cuda()
                assigned_locations = assigned_locations.cuda()

            cls_scores = cls_scores.contiguous().view(-1,
                                                      cls_scores.size()[-1])
            assigned_scores = assigned_scores.contiguous().view(-1)

            cls_loss = F.cross_entropy(cls_scores,
                                       assigned_scores,
                                       weight=cls_weights)

            if total_valid > 0:
                assigned_head_positions = assigned_locations[:, :, 0]
                assigned_head_positions = assigned_head_positions.contiguous(
                ).view(-1)
                #
                assigned_tail_positions = assigned_locations[:, :, 1]
                assigned_tail_positions = assigned_tail_positions.contiguous(
                ).view(-1)

                head_pointer_probs = head_pointer_probs.contiguous().view(
                    -1,
                    head_pointer_probs.size()[-1])
                tail_pointer_probs = tail_pointer_probs.contiguous().view(
                    -1,
                    tail_pointer_probs.size()[-1])

                assigned_head_positions = torch.masked_select(
                    assigned_head_positions, assigned_scores.byte())
                assigned_tail_positions = torch.masked_select(
                    assigned_tail_positions, assigned_scores.byte())

                head_pointer_probs = torch.index_select(
                    head_pointer_probs,
                    dim=0,
                    index=assigned_scores.nonzero().squeeze(1))
                tail_pointer_probs = torch.index_select(
                    tail_pointer_probs,
                    dim=0,
                    index=assigned_scores.nonzero().squeeze(1))

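                # EMD variant: regress one-hot position targets with EMD_L2; otherwise use plain cross-entropy over positions.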
                if args.EMD:
                    assigned_head_positions = to_one_hot(
                        assigned_head_positions, args.seq_len)
                    assigned_tail_positions = to_one_hot(
                        assigned_tail_positions, args.seq_len)

                    prediction_head_loss = EMD_L2(head_pointer_probs,
                                                  assigned_head_positions,
                                                  needSoftMax=True)
                    prediction_tail_loss = EMD_L2(tail_pointer_probs,
                                                  assigned_tail_positions,
                                                  needSoftMax=True)
                else:
                    prediction_head_loss = F.cross_entropy(
                        head_pointer_probs, assigned_head_positions)
                    prediction_tail_loss = F.cross_entropy(
                        tail_pointer_probs, assigned_tail_positions)
                loc_losses.update(
                    prediction_head_loss.data.item() +
                    prediction_tail_loss.data.item(), feature_batch.size(0))
                total_loss = alpha * (prediction_head_loss +
                                      prediction_tail_loss) + cls_loss
            else:
                total_loss = cls_loss

            model_optim.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
            model_optim.step()
            cls_losses.update(cls_loss.data.item(), feature_batch.size(0))
            total_losses.update(total_loss.item(), feature_batch.size(0))

        logger.info(
            "Train -- Epoch :{:06d}, LR: {:.6f},\tloss={:.4f}, \t c-loss:{:.4f}, \tloc-loss:{:.4f}\tcls-Accuracy:{:.4f}\tloc-Avg-IOU:{:.4f}\t topIOU:{:.4f}"
            .format(epoch, model_optim.param_groups[0]['lr'], total_losses.avg,
                    cls_losses.avg, loc_losses.avg, Accuracy.avg, IOU.avg,
                    ordered_IOU.avg))

        optim_scheduler.step(total_losses.avg)

        model.eval()

        # IOU = AverageMeter()
        # pbar = progressbar.ProgressBar(max_value=len(val_evaluator))
        # for i_batch, sample_batch in enumerate(val_dataloader):
        #     pbar.update(i_batch)

        #     feature_batch = Variable(sample_batch[0])
        #     start_indices = Variable(sample_batch[1])
        #     end_indices = Variable(sample_batch[2])
        #     gt_valids = Variable(sample_batch[3])
        #     # valid_indices = Variable(sample_batch[3])

        #     if use_cuda:
        #         feature_batch = feature_batch.cuda()
        #         start_indices = start_indices.cuda()
        #         end_indices = end_indices.cuda()

        #     gt_positions = torch.stack([start_indices, end_indices], dim=-1)

        #     head_pointer_probs, head_positions, tail_pointer_probs, tail_positions, cls_scores, _ = model(
        #         feature_batch)#Update: compared to the previous version, we now update the matching rules

        #     pred_positions = torch.stack([head_positions, tail_positions], dim=-1)
        #     pred_scores = cls_scores[:, :, -1]
        #     #TODO: should NOT change here for evaluation!
        #     assigned_scores, assigned_locations, total_valid, total_iou = h_match.Assign_Batch_v2(gt_positions, pred_positions, gt_valids, thres=hassign_thres)
        #     if total_valid>0:
        #         IOU.update(total_iou / total_valid, total_valid)

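        # Evaluate F1 on the val and test splits with the external evaluators after every epoch.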
        val_F1s = val_evaluator.Evaluate(model)
        test_F1s = test_evaluator.Evaluate(model)

        logger.info("Val -- Epoch :{:06d}, LR: {:.6f},\tF1s:{:.4f}".format(
            epoch, model_optim.param_groups[0]['lr'], val_F1s))
        logger.info("Test -- Epoch :{:06d},\tF1s:{:.4f}".format(
            epoch, test_F1s))

        if epoch % 1 == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'loss': total_losses.avg,
                    'cls_loss': cls_losses.avg,
                    'loc_loss': loc_losses.avg,
                    'IoU': IOU.avg,
                    'val_F1s': val_F1s,
                    'test_F1s': test_F1s
                }, (epoch + 1),
                file_direcotry=save_directory)
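

# Example 4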
def main():
    best_prec1 = 0

    args = parser.parser.parse_args()
    if args.config is not None:
        args = parse_config(args.config)

    script_name_stem = get_stem(__file__)
    current_time_str = get_date_str()

    if args.save_directory is None:
        raise FileNotFoundError(
            "Saving directory should be specified for feature extraction tasks"
        )
    save_directory = get_dir(args.save_directory)

    print("Save to {}".format(save_directory))
    log_file = os.path.join(save_directory,
                            'log-{0}.txt'.format(current_time_str))
    logger = log_utils.get_logger(log_file)
    log_utils.print_config(vars(args), logger)
    print_func = logger.info
    args.log_file = log_file

    if args.device:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn(
            'You have chosen a specific GPU. This will completely disable data parallelism.'
        )

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    if args.arch == 'resnet50_feature_extractor':
        print_func("=> using pre-trained model '{}' to LOAD FEATURES".format(
            args.arch))
        model = models.__dict__[args.arch](pretrained=True,
                                           num_classes=args.num_classes,
                                           param_name=args.paramname)

    else:
        print_func(
            "This script is only for feature extractors! Please double-check the parameters!"
        )
        return

    # if args.freeze:
    #     model = CNN_utils.freeze_all_except_fc(model)

    if args.gpu is not None:
        model = model.cuda(args.gpu)
    else:
        print_func(
            'Please only specify one GPU since we are working in batch size 1 mode'
        )
        return

    cudnn.benchmark = True

    model_total_params = sum(p.numel() for p in model.parameters())
    model_grad_params = sum(p.numel() for p in model.parameters()
                            if p.requires_grad)
    print_func("Total Parameters: {0}\t Gradient Parameters: {1}".format(
        model_total_params, model_grad_params))

    # Data loading code
    val_dataset = get_instance(custom_datasets,
                               '{0}'.format(args.dataset.name), args,
                               **args.dataset.args)
    import tqdm
    import numpy as np

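    # Features are written either as one .npy file per image or accumulated into a single pickle file.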
    if args.individual_feat:
        feature_save_directory = get_dir(
            os.path.join(save_directory, 'individual-features'))
        created_paths = set()
    else:
        data_dict = {}
        feature_save_directory = os.path.join(save_directory, 'feature.pkl')

    model.eval()
    for s_data in tqdm.tqdm(val_dataset, desc="Extracting Features"):
        if s_data is None:
            continue
        s_image_name = s_data[1]
        s_image_data = s_data[0]
        if args.gpu is not None:
            s_image_data = s_image_data.cuda(args.gpu, non_blocking=True)

        output = model(s_image_data.unsqueeze_(0))
        output = output.cpu().data.numpy()
        image_rel_path = os.path.join(
            *(s_image_name.split(os.sep)[-args.rel_path_depth:]))

        if args.individual_feat:
            image_directory = os.path.dirname(image_rel_path)
            if image_directory in created_paths:
                np.save(
                    os.path.join(feature_save_directory,
                                 '{}.npy'.format(image_rel_path)), output)
            else:
                get_dir(os.path.join(feature_save_directory, image_directory))
                np.save(
                    os.path.join(feature_save_directory,
                                 '{}.npy'.format(image_rel_path)), output)
                created_paths.add(image_directory)
        else:
            data_dict[image_rel_path] = output
        # image_name = os.path.basename(s_image_name)
        #
        # if args.individual_feat:
        #         # image_name = os.path.basename(s_image_name)
        #
        #         np.save(os.path.join(feature_save_directory, '{}.npy'.format(image_name)), output)
        #         # created_paths.add(image_directory)
        # else:
        #         data_dict[get_stem(image_name)] = output

    if args.individual_feat:
        print_func("Done")
    else:
        from PyUtils.pickle_utils import save2pickle
        print_func("Saving to a single big file!")

        save2pickle(feature_save_directory, data_dict)
        print_func("Done")
def main():
    global args
    args = parser.parse_args()
    use_cuda = cuda_model.ifUseCuda(args.gpu_id, args.multiGpu)
    script_name_stem = dir_utils.get_stem(__file__)
    save_directory = dir_utils.get_dir(
        os.path.join(
            project_root, 'ckpts',
            'Delete-{:s}-assgin{:.2f}-alpha{:.4f}-dim{:d}-dropout{:.4f}-seqlen{:d}-{:s}-{:s}'
            .format(script_name_stem, args.hassign_thres, args.alpha,
                    args.hidden_dim, args.dropout, args.seq_len,
                    loss_type[args.EMD], match_type[args.hmatch])))
    log_file = os.path.join(save_directory,
                            'log-{:s}.txt'.format(dir_utils.get_date_str()))
    logger = chinese_utils.get_logger(log_file)
    chinese_utils.print_config(vars(args), logger)

    model = PointerNetwork(input_dim=args.input_dim,
                           embedding_dim=args.embedding_dim,
                           hidden_dim=args.hidden_dim,
                           max_decoding_len=args.net_outputs,
                           dropout=args.dropout,
                           n_enc_layers=2)
    hassign_thres = args.hassign_thres
    logger.info("Number of Params\t{:d}".format(
        sum([p.data.nelement() for p in model.parameters()])))
    logger.info('Saving logs to {:s}'.format(log_file))

    if args.resume is not None:

        ckpt_idx = 48

        ckpt_filename = args.resume.format(ckpt_idx)
        assert os.path.isfile(
            ckpt_filename), 'Error: no checkpoint directory found!'

        checkpoint = torch.load(ckpt_filename,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        train_iou = checkpoint['IoU']
        args.start_epoch = checkpoint['epoch']

        logger.info("=> loading checkpoint '{}', current iou: {:.04f}".format(
            ckpt_filename, train_iou))

    model = cuda_model.convertModel2Cuda(model,
                                         gpu_id=args.gpu_id,
                                         multiGpu=args.multiGpu)

    train_dataset = cDataset(dataset_split='train')
    val_dataset = cDataset(dataset_split='val')

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=4)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=4)

    model_optim = optim.Adam(filter(lambda p: p.requires_grad,
                                    model.parameters()),
                             lr=float(args.lr))
    optim_scheduler = optim.lr_scheduler.ReduceLROnPlateau(model_optim,
                                                           'min',
                                                           patience=10)

    alpha = args.alpha
    # cls_weights = torch.FloatTensor([0.05, 1.0]).cuda()
    for epoch in range(args.start_epoch, args.nof_epoch + args.start_epoch):
        total_losses = AverageMeter()
        loc_losses = AverageMeter()
        cls_losses = AverageMeter()
        Accuracy = AverageMeter()
        IOU = AverageMeter()
        ordered_IOU = AverageMeter()
        model.train()
        pbar = progressbar.ProgressBar(max_value=len(train_dataloader))
        for i_batch, sample_batch in enumerate(train_dataloader):
            pbar.update(i_batch)

            feature_batch = Variable(sample_batch[0])
            start_indices = Variable(sample_batch[1])
            end_indices = Variable(sample_batch[2])
            gt_valids = Variable(sample_batch[3])
            # seq_labels = Variable(sample_batch[4])

            if use_cuda:
                feature_batch = feature_batch.cuda()
                start_indices = start_indices.cuda()
                end_indices = end_indices.cuda()

            gt_positions = torch.stack([start_indices, end_indices], dim=-1)

            head_pointer_probs, head_positions, tail_pointer_probs, tail_positions, cls_scores, _ = model(
                feature_batch)

            pred_positions = torch.stack([head_positions, tail_positions],
                                         dim=-1)
            if args.hmatch:
                assigned_scores, assigned_locations, total_valid, total_iou = h_match.Assign_Batch_v2(
                    gt_positions,
                    pred_positions,
                    gt_valids,
                    thres=hassign_thres)
                if total_valid > 0:
                    IOU.update(total_iou / total_valid, total_valid)
            else:
                assigned_scores, assigned_locations = f_match.Assign_Batch(
                    gt_positions,
                    pred_positions,
                    gt_valids,
                    thres=hassign_thres)
                _, _, total_valid, total_iou = h_match.Assign_Batch_v2(
                    gt_positions,
                    pred_positions,
                    gt_valids,
                    thres=hassign_thres)
                if total_valid > 0:
                    IOU.update(total_iou / total_valid, total_valid)

            assigned_scores = Variable(torch.LongTensor(assigned_scores),
                                       requires_grad=False)
            assigned_locations = Variable(torch.LongTensor(assigned_locations),
                                          requires_grad=False)
            if use_cuda:
                assigned_scores = assigned_scores.cuda()
                assigned_locations = assigned_locations.cuda()

            cls_scores = cls_scores.contiguous().view(-1,
                                                      cls_scores.size()[-1])
            assigned_scores = assigned_scores.contiguous().view(-1)

            cls_loss = F.cross_entropy(cls_scores, assigned_scores)

            if total_valid > 0:
                assigned_head_positions = assigned_locations[:, :, 0]
                assigned_head_positions = assigned_head_positions.contiguous(
                ).view(-1)
                #
                assigned_tail_positions = assigned_locations[:, :, 1]
                assigned_tail_positions = assigned_tail_positions.contiguous(
                ).view(-1)

                head_pointer_probs = head_pointer_probs.contiguous().view(
                    -1,
                    head_pointer_probs.size()[-1])
                tail_pointer_probs = tail_pointer_probs.contiguous().view(
                    -1,
                    tail_pointer_probs.size()[-1])

                assigned_head_positions = torch.masked_select(
                    assigned_head_positions, assigned_scores.byte())
                assigned_tail_positions = torch.masked_select(
                    assigned_tail_positions, assigned_scores.byte())

                head_pointer_probs = torch.index_select(
                    head_pointer_probs,
                    dim=0,
                    index=assigned_scores.nonzero().squeeze(1))
                tail_pointer_probs = torch.index_select(
                    tail_pointer_probs,
                    dim=0,
                    index=assigned_scores.nonzero().squeeze(1))

                if args.EMD:
                    assigned_head_positions = to_one_hot(
                        assigned_head_positions, args.seq_len)
                    assigned_tail_positions = to_one_hot(
                        assigned_tail_positions, args.seq_len)

                    prediction_head_loss = EMD_L2(head_pointer_probs,
                                                  assigned_head_positions,
                                                  needSoftMax=True)
                    prediction_tail_loss = EMD_L2(tail_pointer_probs,
                                                  assigned_tail_positions,
                                                  needSoftMax=True)
                else:
                    prediction_head_loss = F.cross_entropy(
                        head_pointer_probs, assigned_head_positions)
                    prediction_tail_loss = F.cross_entropy(
                        tail_pointer_probs, assigned_tail_positions)
                loc_losses.update(
                    prediction_head_loss.data.item() +
                    prediction_tail_loss.data.item(), feature_batch.size(0))
                total_loss = alpha * (prediction_head_loss +
                                      prediction_tail_loss) + cls_loss
            else:
                total_loss = cls_loss

            model_optim.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
            model_optim.step()
            cls_losses.update(cls_loss.data.item(), feature_batch.size(0))
            total_losses.update(total_loss.item(), feature_batch.size(0))

        logger.info(
            "Train -- Epoch :{:06d}, LR: {:.6f},\tloss={:.4f}, \t c-loss:{:.4f}, \tloc-loss:{:.4f}\tcls-Accuracy:{:.4f}\tloc-Avg-IOU:{:.4f}\t topIOU:{:.4f}"
            .format(epoch, model_optim.param_groups[0]['lr'], total_losses.avg,
                    cls_losses.avg, loc_losses.avg, Accuracy.avg, IOU.avg,
                    ordered_IOU.avg))

        optim_scheduler.step(total_losses.avg)

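        # Validation pass: accumulate IoU of segments matched by the evaluation-time assignment.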
        model.eval()

        IOU = AverageMeter()
        pbar = progressbar.ProgressBar(max_value=len(val_dataloader))
        for i_batch, sample_batch in enumerate(val_dataloader):
            pbar.update(i_batch)

            feature_batch = Variable(sample_batch[0])
            start_indices = Variable(sample_batch[1])
            end_indices = Variable(sample_batch[2])
            gt_valids = Variable(sample_batch[3])
            # valid_indices = Variable(sample_batch[4])

            if use_cuda:
                feature_batch = feature_batch.cuda()
                start_indices = start_indices.cuda()
                end_indices = end_indices.cuda()

            gt_positions = torch.stack([start_indices, end_indices], dim=-1)

            head_pointer_probs, head_positions, tail_pointer_probs, tail_positions, cls_scores, _ = model(
                feature_batch
            )  #Update: compared to the previous version, we now update the matching rules

            pred_positions = torch.stack([head_positions, tail_positions],
                                         dim=-1)

            #TODO: should NOT change here for evaluation!
            assigned_scores, assigned_locations, total_valid, total_iou = h_match.Assign_Batch_eval(
                gt_positions, pred_positions, gt_valids, thres=hassign_thres)
            if total_valid > 0:
                IOU.update(total_iou / total_valid, total_valid)

        logger.info(
            "Val -- Epoch :{:06d}, LR: {:.6f},\tloc-Avg-IOU:{:.4f}".format(
                epoch,
                model_optim.param_groups[0]['lr'],
                IOU.avg,
            ))

        if epoch % 1 == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'loss': total_losses.avg,
                    'cls_loss': cls_losses.avg,
                    'loc_loss': loc_losses.avg,
                    'IoU': IOU.avg
                }, (epoch + 1),
                file_direcotry=save_directory)
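

# Example 6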
def main():

    import argparse
    parser = argparse.ArgumentParser(
        description="Pytorch Image CNN training from Configure Files")
    parser.add_argument(
        '--config_file',
        required=True,
        help="This scripts only accepts parameters from Json files")
    input_args = parser.parse_args()

    config_file = input_args.config_file

    args = parse_config(config_file)
    if args.name is None:
        args.name = get_stem(config_file)

    torch.set_default_tensor_type('torch.FloatTensor')
    best_prec1 = 0

    args.script_name = get_stem(__file__)
    current_time_str = get_date_str()
    if args.save_directory is None:
        save_directory = get_dir(
            os.path.join(project_root, args.ckpts_dir,
                         '{:s}'.format(args.name),
                         '{:s}-{:s}'.format(args.ID, current_time_str)))
    else:
        save_directory = get_dir(
            os.path.join(project_root, args.ckpts_dir, args.save_directory))

    print("Save to {}".format(save_directory))
    log_file = os.path.join(save_directory,
                            'log-{0}.txt'.format(current_time_str))
    logger = log_utils.get_logger(log_file)
    log_utils.print_config(vars(args), logger)

    print_func = logger.info
    print_func('ConfigFile: {}'.format(config_file))
    args.log_file = log_file

    if args.device:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    if args.pretrained:
        print_func("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True,
                                           num_classes=args.num_classes)
    else:
        print_func("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=False,
                                           num_classes=args.num_classes)

    if args.freeze:
        model = CNN_utils.freeze_all_except_fc(model)

    if args.gpu is not None:
        model = model.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        print_func(
            'Please only specify one GPU since we are working in batch size 1 mode'
        )
        return

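    # Fine-tuning path: when not evaluating, the final fc weights are excluded from the loaded checkpoint so they can be re-trained.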
    if args.resume:
        if os.path.isfile(args.resume):
            print_func("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            import collections
            if not args.evaluate:
                if isinstance(checkpoint, collections.OrderedDict):
                    load_state_dict(model,
                                    checkpoint,
                                    exclude_layers=['fc.weight', 'fc.bias'])

                else:
                    load_state_dict(
                        model,
                        checkpoint['state_dict'],
                        exclude_layers=['module.fc.weight', 'module.fc.bias'])
                    print_func("=> loaded checkpoint '{}' (epoch {})".format(
                        args.resume, checkpoint['epoch']))
            else:
                if isinstance(checkpoint, collections.OrderedDict):
                    load_state_dict(model, checkpoint, strict=True)

                else:
                    load_state_dict(model,
                                    checkpoint['state_dict'],
                                    strict=True)
                    print_func("=> loaded checkpoint '{}' (epoch {})".format(
                        args.resume, checkpoint['epoch']))
        else:
            print_func("=> no checkpoint found at '{}'".format(args.resume))
            return
    else:
        print_func(
            "=> This script is for fine-tuning only, please double check '{}'".
            format(args.resume))
        print_func("Now using randomly initialized parameters!")

    cudnn.benchmark = True

    model_total_params = sum(p.numel() for p in model.parameters())
    model_grad_params = sum(p.numel() for p in model.parameters()
                            if p.requires_grad)
    print_func("Total Parameters: {0}\t Gradient Parameters: {1}".format(
        model_total_params, model_grad_params))

    # Data loading code
    # val_dataset = get_instance(custom_datasets, '{0}'.format(args.valloader), args)
    from PyUtils.pickle_utils import loadpickle
    from torchvision.datasets.folder import default_loader

    val_dataset = loadpickle(args.val_file)
    image_directory = args.data_dir
    from CNNs.datasets.multilabel import get_val_simple_transform
    val_transform = get_val_simple_transform()
    import tqdm
    import numpy as np

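    # Two output modes: one .npy file per image under 'individual-features', or a
    # single pickle mapping each image's relative path to its feature vector.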
    if args.individual_feat:
        feature_save_directory = get_dir(
            os.path.join(save_directory, 'individual-features'))
        created_paths = set()
    else:
        data_dict = {}
        feature_save_directory = os.path.join(save_directory, 'feature.pkl')

    model.eval()

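    # Feature extraction loop: each validation image is loaded, transformed and
    # pushed through the network one at a time (batch size 1); unreadable images
    # are skipped with a warning.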
    for s_data in tqdm.tqdm(val_dataset, desc="Extracting Features"):
        if s_data is None:
            continue

        image_path = os.path.join(image_directory, s_data[0])

        try:
            input_image = default_loader(image_path)
        except Exception:
            print_func("WARN: failed to load {}, skipping".format(image_path))
            continue

        input_image = val_transform(input_image)

        if args.gpu is not None:
            input_image = input_image.cuda(args.gpu, non_blocking=True)

        output = model(input_image.unsqueeze_(0))
        output = output.cpu().data.numpy()
        # image_rel_path = os.path.join(*(s_image_name.split(os.sep)[-int(args.rel_path_depth):]))

        if args.individual_feat:
            save_path = os.path.join(feature_save_directory,
                                     '{}.npy'.format(s_data[0]))
            parent_dir = os.path.dirname(save_path)
            # Create each output sub-directory only once.
            if parent_dir not in created_paths:
                get_dir(parent_dir)
                created_paths.add(parent_dir)
            np.save(save_path, output)
        else:
            data_dict[s_data[0]] = output
        # image_name = os.path.basename(s_image_name)
        #
        # if args.individual_feat:
        #         # image_name = os.path.basename(s_image_name)
        #
        #         np.save(os.path.join(feature_save_directory, '{}.npy'.format(image_name)), output)
        #         # created_paths.add(image_directory)
        # else:
        #         data_dict[get_stem(image_name)] = output

    if args.individual_feat:
        print_func("Done")
    else:
        from PyUtils.pickle_utils import save2pickle
        print_func("Saving to a single big file!")

        save2pickle(feature_save_directory, data_dict)
        print_func("Done")
def main():

    import argparse
    parser = argparse.ArgumentParser(
        description="Pytorch Image CNN training from Configure Files")
    parser.add_argument(
        '--config_file',
        required=True,
        help="This scripts only accepts parameters from Json files")
    input_args = parser.parse_args()

    config_file = input_args.config_file

    args = parse_config(config_file)
    if args.name is None:
        args.name = get_stem(config_file)

    torch.set_default_tensor_type('torch.FloatTensor')

    args.script_name = get_stem(__file__)
    current_time_str = get_date_str()
    if args.resume is None:
        if args.save_directory is None:
            save_directory = get_dir(
                os.path.join(project_root, 'ckpts', '{:s}'.format(args.name),
                             '{:s}-{:s}'.format(args.ID, current_time_str)))
        else:
            save_directory = get_dir(
                os.path.join(project_root, 'ckpts', args.save_directory))
    else:
        if args.save_directory is None:
            save_directory = os.path.dirname(args.resume)
        else:
            current_time_str = get_date_str()
            save_directory = get_dir(
                os.path.join(args.save_directory, '{:s}'.format(args.name),
                             '{:s}-{:s}'.format(args.ID, current_time_str)))
    print("Save to {}".format(save_directory))
    log_file = os.path.join(save_directory,
                            'log-{0}.txt'.format(current_time_str))
    logger = log_utils.get_logger(log_file)
    log_utils.print_config(vars(args), logger)

    print_func = logger.info
    print_func('ConfigFile: {}'.format(config_file))
    args.log_file = log_file

    if args.device:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    #args.distributed = args.world_size > 1
    args.distributed = False

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    num_datasets = args.num_datasets
    # model_list = [None for x in range(num_datasets)]
    # for j in range(num_datasets):
    if args.pretrained:
        print_func("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True,
                                           num_classes=args.class_len)
    else:
        print_func("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=False,
                                           num_classes=args.class_len)

    if args.freeze:
        model = CNN_utils.freeze_all_except_fc(model)

    if args.gpu is not None:
        model = model.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
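        # Single-machine multi-GPU fallback: for AlexNet/VGG only model.features is
        # wrapped in DataParallel (their classifiers stay on one device); other
        # architectures are wrapped in DataParallel as a whole.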
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # # define loss function (criterion) and optimizer
    # # # Update: here
    # # config = {'loss': {'type': 'simpleCrossEntropyLoss', 'args': {'param': None}}}
    # # criterion = get_instance(loss_funcs, 'loss', config)
    # # criterion = criterion.cuda(args.gpu)
    #
    criterion = nn.CrossEntropyLoss(ignore_index=-1).cuda(args.gpu)
    # criterion = MclassCrossEntropyLoss().cuda(args.gpu)

    # params = list()
    # for j in range(num_datasets):
    #     params += list(model_list[j].parameters())

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                       model.parameters()),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

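    # LR schedule: MultiStepLR with the comma-separated epoch milestones from
    # args.lr_schedule, otherwise ReduceLROnPlateau. Note that only the MultiStepLR
    # branch is actually stepped inside the epoch loop below.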
    if args.lr_schedule:
        print_func("Using scheduled learning rate")
        scheduler = lr_scheduler.MultiStepLR(
            optimizer, [int(i) for i in args.lr_schedule.split(',')],
            gamma=0.1)
    else:
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   'min',
                                                   patience=args.lr_patience)
    '''
    if args.resume:
        if os.path.isfile(args.resume):
            print_func("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            import collections
            if not args.evaluate:
                if isinstance(checkpoint, collections.OrderedDict):
                    load_state_dict(model, checkpoint, exclude_layers=['fc.weight', 'fc.bias'])


                else:
                    load_state_dict(model, checkpoint['state_dict'], exclude_layers=['module.fc.weight', 'module.fc.bias'])
                    print_func("=> loaded checkpoint '{}' (epoch {})"
                          .format(args.resume, checkpoint['epoch']))
            else:
                if isinstance(checkpoint, collections.OrderedDict):
                    load_state_dict(model, checkpoint, strict=True)
                else:
                    load_state_dict(model, checkpoint['state_dict'], strict=True)
                    print_func("=> loaded checkpoint '{}' (epoch {})"
                               .format(args.resume, checkpoint['epoch']))
        else:
            print_func("=> no checkpoint found at '{}'".format(args.resume))
            return
    '''

    cudnn.benchmark = True

    model_total_params = sum(p.numel() for p in model.parameters())
    model_grad_params = sum(p.numel() for p in model.parameters()
                            if p.requires_grad)
    print_func("Total Parameters: {0}\t Gradient Parameters: {1}".format(
        model_total_params, model_grad_params))

    # Data loading code
    val_loaders = [None for x in range(num_datasets)]
    test_loaders = [None for x in range(num_datasets)]
    train_loaders = [None for x in range(num_datasets)]
    num_iter = 0
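    # Build one val/test/train loader per dataset. num_iter records the size of the
    # largest training set, presumably so train() can run the same number of
    # iterations over every dataset per epoch.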
    for k in range(num_datasets):
        args.ind = k

        val_dataset = get_instance(custom_datasets, args.val_loader, args)
        if val_dataset is None or k == num_datasets - 1:
            val_loaders[args.ind] = None
        else:
            val_loaders[args.ind] = torch.utils.data.DataLoader(
                val_dataset,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                collate_fn=none_collate)

        if hasattr(args, 'test_files') and hasattr(args, 'test_loader'):
            test_dataset = get_instance(custom_datasets, args.test_loader,
                                        args)
            test_loaders[args.ind] = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                collate_fn=none_collate)
        else:
            # test_dataset = None
            test_loaders[args.ind] = None

        #if args.evaluate:
        #    validate(test_loaders[args.ind], model_list[k], criterion, args, print_func)
        #    return
        # if not args.evaluate: #else:
        #     train_samplers = [None for x in range(num_datasets)]
        #     train_dataset = get_instance(custom_datasets, args.train_loader, args)
        #
        #     if args.distributed:
        #         train_samplers[args.ind] = torch.utils.data.distributed.DistributedSampler(train_dataset)
        #     else:
        #         train_samplers[args.ind] = None
        #
        #     train_loaders[args.ind] = torch.utils.data.DataLoader(
        #         train_dataset, batch_size=args.batch_size, shuffle=(train_samplers[args.ind] is None),
        #         num_workers=args.workers, pin_memory=True, sampler=train_samplers[args.ind], collate_fn=none_collate)
        if not args.evaluate:  #else:
            # train_samplers = [None for x in range(num_datasets)]
            train_dataset = get_instance(custom_datasets, args.train_loader,
                                         args)

            num_iter = max(num_iter, len(train_dataset.samples))
            if args.distributed:
                train_samplers = torch.utils.data.distributed.DistributedSampler(
                    train_dataset)
            else:
                train_samplers = None

            train_loaders[args.ind] = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=train_samplers is None,
                num_workers=args.workers,
                pin_memory=True,
                sampler=train_samplers,
                collate_fn=none_collate)
    setattr(args, 'num_iter', num_iter)

    # TRAINING
    best_prec1 = [-1 for _ in range(num_datasets)]
    is_best = [None for _ in range(num_datasets)]
    setattr(args, 'lam', 0.5)

    start_data_time = time.time()
    train_loads_iters = [iter(train_loaders[x]) for x in range(num_datasets)]
    print_func("Loaded data in {:.3f} s".format(time.time() - start_data_time))
    for epoch in range(args.start_epoch, args.epochs):

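        # NOTE: distributed training is disabled above (args.distributed = False).
        # If it is re-enabled, train_samplers must be kept as a per-dataset list for
        # this set_epoch call to work as written.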
        if args.distributed:
            for x in range(num_datasets):
                train_samplers[x].set_epoch(epoch)
        if args.lr_schedule:
            # CNN_utils.adjust_learning_rate(optimizer, epoch, args.lr)
            scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        print_func("Epoch: [{}], learning rate: {}".format(epoch, current_lr))

        # train for one epoch
        train(train_loads_iters, train_loaders, model, criterion, optimizer,
              epoch, args, print_func)

        # evaluate and save
        val_prec1 = [None for x in range(num_datasets)]
        test_prec1 = [None for x in range(num_datasets)]
        for j in range(num_datasets):
            # if j != args.ind:
            #     load_state_dict(model_list[j], model_list[args.ind].state_dict())
            # evaluate on validation set
            if val_loaders[j]:
                val_prec1[j], _ = validate(val_loaders[j], model, criterion,
                                           args, print_func, j)
            else:
                val_prec1[j] = 0
            # remember best prec@1 and save checkpoint
            is_best[j] = val_prec1[j] > best_prec1[j]
            best_prec1[j] = max(val_prec1[j], best_prec1[j])

            if is_best[j]:
                save_ind = j
            else:
                save_ind = "#"
            CNN_utils.save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1[j],
                    'optimizer': optimizer.state_dict(),
                },
                is_best[j],
                file_directory=save_directory,
                epoch=epoch,
                save_best_only=args.save_best_only,
                ind=save_ind)

            test_prec1[j], _ = validate(test_loaders[j],
                                        model,
                                        criterion,
                                        args,
                                        print_func,
                                        j,
                                        phase='Test')

        print_func("Val precisions: {}".format(val_prec1))
        print_func("Test precisions: {}".format(test_prec1))
def main():
    best_prec1 = 0

    args = parser.parser.parse_args()
    config_file = None
    if args.config is not None:
        config_file = args.config
        args = parse_config(args.config)

    script_name_stem = get_stem(__file__)
    current_time_str = get_date_str()
    if args.resume is None:
        if args.save_directory is None:
            save_directory = get_dir(
                os.path.join(
                    project_root, 'ckpts', '{:s}'.format(args.name),
                    '{:s}-{:s}-{:s}'.format(script_name_stem, args.ID,
                                            current_time_str)))
        else:
            save_directory = get_dir(
                os.path.join(project_root, 'ckpts', args.save_directory))
    else:
        save_directory = os.path.dirname(args.resume)
    print("Save to {}".format(save_directory))
    log_file = os.path.join(save_directory,
                            'log-{0}.txt'.format(current_time_str))
    logger = log_utils.get_logger(log_file)
    log_utils.print_config(vars(args), logger)

    print_func = logger.info
    if config_file is not None:
        print_func('ConfigFile: {}'.format(config_file))
    else:
        print_func('ConfigFile: None, params from argparse')

    args.log_file = log_file

    if args.device:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.arch == 'resnet50otherinits':
        print_func("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True,
                                           num_classes=args.num_classes,
                                           param_name=args.paramname)
    elif args.arch == 'resnet50_feature_extractor':
        print_func("=> using pre-trained model '{}' to LOAD FEATURES".format(
            args.arch))
        model = models.__dict__[args.arch](pretrained=True,
                                           num_classes=args.num_classes,
                                           param_name=args.paramname)

    else:
        if args.pretrained:
            print_func("=> using pre-trained model '{}'".format(args.arch))
            model = models.__dict__[args.arch](pretrained=True,
                                               num_classes=args.num_classes)
        else:
            print_func("=> creating model '{}'".format(args.arch))
            model = models.__dict__[args.arch](pretrained=False,
                                               num_classes=args.num_classes)

    if args.freeze:
        model = CNN_utils.freeze_all_except_fc(model)

    if args.gpu is not None:
        model = model.cuda(args.gpu)
    elif args.distributed:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    # # Update: here
    # config = {'loss': {'type': 'simpleCrossEntropyLoss', 'args': {'param': None}}}
    # criterion = get_instance(loss_funcs, 'loss', config)
    # criterion = criterion.cuda(args.gpu)

    criterion = nn.CrossEntropyLoss(ignore_index=-1).cuda(args.gpu)

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                       model.parameters()),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.lr_schedule:
        print_func("Using scheduled learning rate")
        scheduler = lr_scheduler.MultiStepLR(
            optimizer, [int(i) for i in args.lr_schedule.split(',')],
            gamma=0.1)
    else:
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   'min',
                                                   patience=args.lr_patience)

    # optimizer = torch.optim.SGD(model.parameters(), args.lr,
    #                             momentum=args.momentum,
    #                             weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_func("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_func("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print_func("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    model_total_params = sum(p.numel() for p in model.parameters())
    model_grad_params = sum(p.numel() for p in model.parameters()
                            if p.requires_grad)
    print_func("Total Parameters: {0}\t Gradient Parameters: {1}".format(
        model_total_params, model_grad_params))

    # Data loading code
    val_dataset = get_instance(custom_datasets,
                               '{0}_val'.format(args.dataset.name), args,
                               **args.dataset.args)
    if val_dataset is None:
        val_loader = None
    else:
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True,
                                                 collate_fn=none_collate)

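    # In evaluation mode the script only runs validation and returns; otherwise a
    # training loader is built, with a DistributedSampler when running distributed.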
    if args.evaluate:
        validate(val_loader, model, criterion, args, print_func)
        return
    else:

        train_dataset = get_instance(custom_datasets,
                                     '{0}_train'.format(args.dataset.name),
                                     args, **args.dataset.args)

        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset)
        else:
            train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler,
            collate_fn=none_collate)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        if args.lr_schedule:
            # CNN_utils.adjust_learning_rate(optimizer, epoch, args.lr)
            scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        print_func("Epoch: [{}], learning rate: {}".format(epoch, current_lr))

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args,
              print_func)

        # evaluate on validation set
        if val_loader:
            prec1, val_loss = validate(val_loader, model, criterion, args,
                                       print_func)
        else:
            prec1 = 0
            val_loss = 0
        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        CNN_utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            file_directory=save_directory,
            epoch=epoch)

        if not args.lr_schedule:
            scheduler.step(val_loss)
def main():
    global args
    args = parser.parse_args()
    use_cuda = cuda_model.ifUseCuda(args.gpu_id, args.multiGpu)
    script_name_stem = dir_utils.get_stem(__file__)
    save_directory = dir_utils.get_dir(
        os.path.join(
            project_root, 'ckpts',
            '{:s}-assgin{:.2f}-alpha{:.4f}-dim{:d}-dropout{:.4f}-seqlen{:d}-{:s}-{:s}'
            .format(script_name_stem, args.hassign_thres, args.alpha,
                    args.hidden_dim, args.dropout, args.seq_len,
                    loss_type[args.EMD], match_type[args.hmatch])))
    log_file = os.path.join(save_directory,
                            'log-{:s}.txt'.format(dir_utils.get_date_str()))
    logger = log_utils.get_logger(log_file)
    log_utils.print_config(vars(args), logger)

    model = BaseLSTMNetwork(input_dim=args.input_dim,
                            embedding_dim=args.embedding_dim,
                            hidden_dim=args.hidden_dim,
                            max_decoding_len=args.net_outputs,
                            dropout=args.dropout,
                            n_enc_layers=2)
    hassign_thres = args.hassign_thres
    logger.info("Number of Params\t{:d}".format(
        sum([p.data.nelement() for p in model.parameters()])))
    logger.info('Saving logs to {:s}'.format(log_file))

    if args.resume is not None:

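        # NOTE: the checkpoint index is hard-coded; args.resume is expected to be a
        # format string into which this index is substituted.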
        ckpt_idx = 48

        ckpt_filename = args.resume.format(ckpt_idx)
        assert os.path.isfile(
            ckpt_filename), 'Error: no checkpoint file found!'

        checkpoint = torch.load(ckpt_filename,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        train_iou = checkpoint['IoU']
        args.start_epoch = checkpoint['epoch']

        logger.info("=> loading checkpoint '{}', current iou: {:.04f}".format(
            ckpt_filename, train_iou))

    model = cuda_model.convertModel2Cuda(model,
                                         gpu_id=args.gpu_id,
                                         multiGpu=args.multiGpu)

    train_dataset = MNIST(dataset_split='train')
    val_dataset = MNIST(dataset_split='val')

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=4)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=4)

    model_optim = optim.Adam(filter(lambda p: p.requires_grad,
                                    model.parameters()),
                             lr=float(args.lr))
    optim_scheduler = optim.lr_scheduler.ReduceLROnPlateau(model_optim,
                                                           'min',
                                                           patience=10)

    alpha = args.alpha
    # cls_weights = torch.FloatTensor([0.05, 1.0]).cuda()
    for epoch in range(args.start_epoch, args.nof_epoch + args.start_epoch):
        total_losses = AverageMeter()
        loc_losses = AverageMeter()
        cls_losses = AverageMeter()
        Accuracy = AverageMeter()
        IOU = AverageMeter()
        ordered_IOU = AverageMeter()
        model.train()
        pbar = progressbar.ProgressBar(max_value=len(train_dataloader))
        for i_batch, sample_batch in enumerate(train_dataloader):
            pbar.update(i_batch)

            feature_batch = Variable(sample_batch[0])
            labels = Variable(sample_batch[1])

            if use_cuda:
                feature_batch = feature_batch.cuda()
                labels = labels.cuda()
                # end_indices = end_indices.cuda()

            pred_labels = model(feature_batch)

            labels = labels.contiguous().view(-1)
            pred_labels = pred_labels.contiguous().view(
                -1,
                pred_labels.size()[-1])

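            # Predictions are binarized to {1, -1}, so equality with the 0/1 labels
            # can only happen at true positives; dividing by the number of positive
            # labels gives a recall-style score that is logged as IoU.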
            pred_probs = F.softmax(pred_labels, dim=1)[:, 1]
            pred_probs[pred_probs > 0.5] = 1
            pred_probs[pred_probs <= 0.5] = -1
            n_positives = torch.sum(labels).item()
            iou = torch.sum(
                pred_probs == labels.float()).item() * 1. / n_positives
            IOU.update(iou, 1.)

            total_loss = F.cross_entropy(pred_labels, labels)

            model_optim.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
            model_optim.step()
            # cls_losses.update(cls_loss.data.item(), feature_batch.size(0))
            total_losses.update(total_loss.item(), feature_batch.size(0))

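        # Only total_losses and IOU are updated in this loop; cls_losses, loc_losses,
        # Accuracy and ordered_IOU are logged below but remain at their initial zero.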
        logger.info(
            "Train -- Epoch :{:06d}, LR: {:.6f},\tloss={:.4f}, \t c-loss:{:.4f}, \tloc-loss:{:.4f}\tcls-Accuracy:{:.4f}\tloc-Avg-IOU:{:.4f}\t topIOU:{:.4f}"
            .format(epoch, model_optim.param_groups[0]['lr'], total_losses.avg,
                    cls_losses.avg, loc_losses.avg, Accuracy.avg, IOU.avg,
                    ordered_IOU.avg))

        optim_scheduler.step(total_losses.avg)

        model.eval()

        IOU = AverageMeter()
        pbar = progressbar.ProgressBar(max_value=len(val_dataloader))
        for i_batch, sample_batch in enumerate(val_dataloader):
            pbar.update(i_batch)

            feature_batch = Variable(sample_batch[0])
            labels = Variable(sample_batch[1])

            if use_cuda:
                feature_batch = feature_batch.cuda()
                labels = labels.cuda()

            labels = labels.contiguous().view(-1)

            pred_labels = model(feature_batch)

            pred_labels = pred_labels.contiguous().view(
                -1,
                pred_labels.size()[-1])

            pred_probs = F.softmax(pred_labels, dim=1)[:, 1]
            n_positives = torch.sum(labels).item()
            pred_probs[pred_probs > 0.5] = 1
            pred_probs[pred_probs <= 0.5] = -1

            iou = torch.sum(
                pred_probs == labels.float()).item() * 1. / n_positives
            IOU.update(iou, 1.)

        logger.info(
            "Val -- Epoch :{:06d}, LR: {:.6f},\tloc-Avg-IOU:{:.4f}".format(
                epoch,
                model_optim.param_groups[0]['lr'],
                IOU.avg,
            ))

        if epoch % 1 == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'loss': total_losses.avg,
                    'cls_loss': cls_losses.avg,
                    'loc_loss': loc_losses.avg,
                    'IoU': IOU.avg
                }, (epoch + 1),
                file_direcotry=save_directory)
def main():
    global args
    args = parser.parse_args()
    use_cuda = cuda_model.ifUseCuda(args.gpu_id, args.multiGpu)
    script_name_stem = dir_utils.get_stem(__file__)
    save_directory = dir_utils.get_dir(
        os.path.join(
            project_root, 'ckpts',
            '{:s}-{:s}-{:s}-split-{:d}-decoderatio-{:.2f}-alpha{:.4f}-dim{:d}-dropout{:.4f}'
            .format(script_name_stem, args.dataset, args.eval_metrics,
                    args.split, args.decode_ratio, args.alpha, args.hidden_dim,
                    args.dropout)))
    log_file = os.path.join(save_directory,
                            'log-{:s}.txt'.format(dir_utils.get_date_str()))
    logger = log_utils.get_logger(log_file)
    log_utils.print_config(vars(args), logger)

    # get train/val split
    if args.dataset == 'SumMe':
        train_val_perms = np.arange(25)
    elif args.dataset == 'TVSum':
        train_val_perms = np.arange(50)
    else:
        raise ValueError('Unsupported dataset: {}'.format(args.dataset))
    # fixed permutation
    random.Random(0).shuffle(train_val_perms)
    train_val_perms = train_val_perms.reshape([5, -1])
    train_perms = np.delete(train_val_perms, args.split, 0).reshape([-1])
    val_perms = train_val_perms[args.split]
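    # 5-fold split: fold args.split (0-4) is held out for validation, the remaining
    # four folds form the training split.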
    logger.info(" training split: " + str(train_perms))
    logger.info(" val split: " + str(val_perms))

    if args.location == 'home':
        data_path = os.path.join(os.path.expanduser('~'), 'datasets')
    else:
        data_path = os.path.join('/nfs/%s/boyu/SDN' % (args.location),
                                 'datasets')
    train_dataset = vsTVSum_Loader3_c3dd_segment.cDataset(
        dataset_name=args.dataset,
        split='train',
        decode_ratio=args.decode_ratio,
        train_val_perms=train_perms,
        data_path=data_path)
    max_input_len = train_dataset.max_input_len
    maximum_outputs = int(args.decode_ratio * max_input_len)
    val_dataset = vsTVSum_Loader3_c3dd_segment.cDataset(
        dataset_name=args.dataset,
        split='val',
        decode_ratio=args.decode_ratio,
        train_val_perms=val_perms,
        data_path=data_path)
    train_evaluator = Evaluator.Evaluator(dataset_name=args.dataset,
                                          split='tr',
                                          max_input_len=max_input_len,
                                          maximum_outputs=maximum_outputs,
                                          sum_budget=0.15,
                                          train_val_perms=train_perms,
                                          eval_metrics=args.eval_metrics,
                                          data_path=data_path)
    val_evaluator = Evaluator.Evaluator(dataset_name=args.dataset,
                                        split='val',
                                        max_input_len=max_input_len,
                                        maximum_outputs=maximum_outputs,
                                        sum_budget=0.15,
                                        train_val_perms=val_perms,
                                        eval_metrics=args.eval_metrics,
                                        data_path=data_path)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=4)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=4)

    model = PointerNetwork(input_dim=args.input_dim,
                           embedding_dim=args.embedding_dim,
                           hidden_dim=args.hidden_dim,
                           max_decoding_len=maximum_outputs,
                           dropout=args.dropout,
                           n_enc_layers=2,
                           output_classes=1)
    # hassign_thres = args.hassign_thres
    logger.info("Number of Params\t{:d}".format(
        sum([p.data.nelement() for p in model.parameters()])))
    logger.info('Saving logs to {:s}'.format(log_file))

    if args.resume is not None:

        ckpt_idx = 48

        ckpt_filename = args.resume.format(ckpt_idx)
        assert os.path.isfile(
            ckpt_filename), 'Error: no checkpoint file found!'

        checkpoint = torch.load(ckpt_filename,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        train_iou = checkpoint['IoU']
        args.start_epoch = checkpoint['epoch']

        logger.info("=> loading checkpoint '{}', current iou: {:.04f}".format(
            ckpt_filename, train_iou))

    model = cuda_model.convertModel2Cuda(model,
                                         gpu_id=args.gpu_id,
                                         multiGpu=args.multiGpu)

    model_optim = optim.Adam(filter(lambda p: p.requires_grad,
                                    model.parameters()),
                             lr=float(args.lr))
    optim_scheduler = optim.lr_scheduler.ReduceLROnPlateau(model_optim,
                                                           'min',
                                                           patience=10)

    alpha = args.alpha
    # cls_weights = torch.FloatTensor([0.2, 1.0]).cuda()
    # if args.set_cls_weight:
    #     cls_weights = torch.FloatTensor([1.*train_dataset.n_positive_train_samples/train_dataset.n_total_train_samples, args.cls_pos_weight]).cuda()
    # else:
    #     cls_weights = torch.FloatTensor([0.5, 0.5]).cuda()
    # logger.info(" total: {:d}, total pos: {:d}".format(train_dataset.n_total_train_samples, train_dataset.n_positive_train_samples))
    # logger.info(" classify weight: " + str(cls_weights[0]) + str(cls_weights[1]))
    for epoch in range(args.start_epoch, args.nof_epoch + args.start_epoch):
        total_losses = AverageMeter()
        # loc_losses = AverageMeter()
        pointer_losses = AverageMeter()
        rgs_losses = AverageMeter()
        Accuracy = AverageMeter()
        # IOU = AverageMeter()
        # ordered_IOU = AverageMeter()
        model.train()
        pbar = progressbar.ProgressBar(max_value=len(train_dataloader))
        for i_batch, sample_batch in enumerate(train_dataloader):
            pbar.update(i_batch)

            feature_batch = Variable(sample_batch[0])
            pointer_indices = Variable(sample_batch[1])
            pointer_scores = Variable(sample_batch[2])
            gt_valids = Variable(sample_batch[3])
            # seq_labels = Variable(sample_batch[3])

            if use_cuda:
                feature_batch = feature_batch.cuda()
                pointer_indices = pointer_indices.cuda()
                pointer_scores = pointer_scores.cuda()

            gt_positions = pointer_indices
            gt_scores = pointer_scores

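            # The pointer network returns per-step pointer distributions, decoded
            # positions and per-step scores. Training combines a cross-entropy loss
            # over the pointer distributions with an MSE regression on the scores,
            # weighted by alpha.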
            pointer_probs, pointer_positions, cls_scores, _ = model(
                feature_batch)

            pred_positions = pointer_positions

            cls_scores = cls_scores.contiguous().squeeze(2)

            # print(pointer_probs.size())
            # print(gt_positions.size())
            pointer_loss = F.cross_entropy(pointer_probs.permute(0, 2, 1),
                                           gt_positions)

            rgs_loss = F.mse_loss(cls_scores, gt_scores)

            total_loss = alpha * pointer_loss + rgs_loss

            model_optim.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
            model_optim.step()
            pointer_losses.update(pointer_loss.data.item(),
                                  feature_batch.size(0))
            rgs_losses.update(rgs_loss.data.item(), feature_batch.size(0))
            total_losses.update(total_loss.item(), feature_batch.size(0))

        logger.info(
            "Train -- Epoch :{:06d}, LR: {:.6f},\tloss={:.4f}, \t pointer-loss:{:.4f}, \tregress-loss:{:.4f}\tcls-Accuracy:{:.4f}"
            .format(epoch, model_optim.param_groups[0]['lr'], total_losses.avg,
                    pointer_losses.avg, rgs_losses.avg, Accuracy.avg))

        optim_scheduler.step(total_losses.avg)

        model.eval()

        pointer_losses = AverageMeter()
        rgs_losses = AverageMeter()
        pbar = progressbar.ProgressBar(max_value=len(val_dataloader))
        for i_batch, sample_batch in enumerate(val_dataloader):
            pbar.update(i_batch)

            feature_batch = Variable(sample_batch[0])
            pointer_indices = Variable(sample_batch[1])
            pointer_scores = Variable(sample_batch[2])
            gt_valids = Variable(sample_batch[3])
            # valid_indices = Variable(sample_batch[3])

            if use_cuda:
                feature_batch = feature_batch.cuda()
                pointer_indices = pointer_indices.cuda()
                pointer_scores = pointer_scores.cuda()

            gt_positions = pointer_indices
            gt_scores = pointer_scores

            pointer_probs, pointer_positions, cls_scores, _ = model(
                feature_batch)

            pred_positions = pointer_positions

            cls_scores = cls_scores.contiguous().squeeze(2)

            pointer_loss = F.cross_entropy(pointer_probs.permute(0, 2, 1),
                                           gt_positions)

            rgs_loss = F.mse_loss(cls_scores, gt_scores)

            pointer_losses.update(pointer_loss.data.item(),
                                  feature_batch.size(0))
            rgs_losses.update(rgs_loss.data.item(), feature_batch.size(0))

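        # Epoch-level F1 scores are computed by the Evaluator objects over the full
        # train and val splits.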
        train_F1s = train_evaluator.Evaluate(model)
        val_F1s = val_evaluator.Evaluate(model)

        logger.info("Train -- Epoch :{:06d},\tF1s:{:.4f}".format(
            epoch, train_F1s))

        logger.info(
            "Val -- Epoch :{:06d},\t pointer-loss:{:.4f}, \tregress-loss:{:.4f}, \tF1s{:.4f}"
            .format(epoch, pointer_losses.avg, rgs_losses.avg, val_F1s))

        if epoch % 1 == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'loss': total_losses.avg,
                    'pointer_loss': pointer_losses.avg,
                    'rgs_loss': rgs_losses.avg,
                    'val_F1s': val_F1s
                }, (epoch + 1),
                file_direcotry=save_directory)