def fine_tune_train_and_val(args, recorder):
    """Fine-tune a model, validating and checkpointing it periodically.

    Builds the fine-tuning loaders, model, criteria and optimizer through the
    ``ft_*`` helpers, then trains for epochs in
    ``range(args.ft_start_epoch, args.ft_epochs)``. Every
    ``args.ft_eval_freq`` epochs the model is validated. A checkpoint is then
    saved through ``recorder.save_ft_model``, flagged as best when the
    validation precision beats the module-level ``best_prec1``.

    Args:
        args: Run configuration; the ``ft_*`` fields are read here.
        recorder: Experiment recorder used for metrics, log messages and
            checkpoints.

    Returns:
        ``recorder.filename``.
    """
    global lowest_val_loss, best_prec1
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # close the warning
    torch.manual_seed(1)
    cudnn.benchmark = True
    timer = Timer()
    # == dataset config==
    num_class, data_length, image_tmpl = ft_data_config(args)
    train_transforms, test_transforms, eval_transforms = ft_augmentation_config(
        args)
    train_data_loader, val_data_loader, _, _, _, _ = ft_data_loader_init(
        args, data_length, image_tmpl, train_transforms, test_transforms,
        eval_transforms)
    # == model config==
    model = ft_model_config(args, num_class)
    recorder.record_message('a', '=' * 100)
    recorder.record_message('a', '-' * 40 + 'finetune' + '-' * 40)
    recorder.record_message('a', '=' * 100)
    # == optim config==
    train_criterion, val_criterion, optimizer = ft_optim_init(args, model)
    # == data augmentation(self-supervised) config==
    tc = TC(args)
    # == train and eval==
    print('*' * 70 + 'Step2: fine tune' + '*' * 50)
    for epoch in range(args.ft_start_epoch, args.ft_epochs):
        timer.tic()
        ft_adjust_learning_rate(optimizer, args.ft_lr, epoch, args.ft_lr_steps)
        train_prec1, train_loss = train(args, tc, train_data_loader, model,
                                        train_criterion, optimizer, epoch,
                                        recorder)
        # Values are rescaled before recording (loss / 5.0, prec1 / 100.0).
        recorder.record_ft_train(train_loss / 5.0, train_prec1 / 100.0)
        if (epoch + 1) % args.ft_eval_freq == 0:
            val_prec1, val_loss = validate(args, tc, val_data_loader, model,
                                           val_criterion, recorder)
            recorder.record_ft_val(val_loss / 5.0, val_prec1 / 100.0)
            is_best = val_prec1 > best_prec1
            best_prec1 = max(val_prec1, best_prec1)
            checkpoint = {
                'epoch': epoch + 1,
                'arch': "i3d",
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1
            }
            # Save only on evaluation epochs. `checkpoint` and `is_best` are
            # bound only here, so saving outside this branch raised
            # UnboundLocalError whenever ft_eval_freq > 1. It also re-used a
            # stale `is_best` on later non-eval epochs.
            recorder.save_ft_model(checkpoint, is_best)
        timer.toc()
        # Number of epochs still to run after the current one.
        left_time = timer.average_time * (args.ft_epochs - epoch - 1)
        message = "Step2: fine tune best_prec1 is: {} left time is : {} now is : {}".format(
            best_prec1, timer.format(left_time), datetime.now())
        print(message)
        recorder.record_message('a', message)
    return recorder.filename
# Example no. 2
# 0
def pretext_train(args, recorder):
    """Run self-supervised pretext (pre-)training and checkpoint every epoch.

    Supported values of ``args.pt_method`` are ``'moco'``, ``'dsm'`` and
    ``'dsm_triplet'``. Each selects the matching ``train_*`` routine.

    Side effects on ``args``: sets ``args.start_epoch = 1`` and replaces
    ``args.pt_lr_decay_epochs`` (a comma-separated string, or an
    already-parsed sequence) with a list of ints.

    Args:
        args: Run configuration; the ``pt_*`` fields are read here.
        recorder: Experiment recorder used for losses, messages and
            checkpoints.

    Returns:
        ``recorder.pt_checkpoint``.

    Raises:
        ValueError: If ``args.pt_method`` is not a supported method. The
            error is raised when the first epoch runs.
    """
    if args.gpus is not None:
        print("Use GPU: {} for pretext training".format(args.gpus))
    num_class, data_length, image_tmpl = pt_data_config(args)
    train_transforms, test_transforms, eval_transforms = pt_augmentation_config(
        args)
    # Only the training loader is used for pretext training.
    train_loader, _, _, _, _, _ = pt_data_loader_init(
        args, data_length, image_tmpl, train_transforms, test_transforms,
        eval_transforms)

    n_data = len(train_loader)

    model, model_ema = pt_model_config(args, num_class)
    # == optim config==
    contrast, criterion, optimizer = pt_optim_init(args, model, n_data)
    model = model.cuda()
    # == load weights ==
    model, model_ema = pt_load_weight(args, model, model_ema, optimizer,
                                      contrast)
    if args.pt_method in ['dsm', 'moco']:
        model_ema = model_ema.cuda()
        # copy weights from `model' to `model_ema'
        moment_update(model, model_ema, 0)
    cudnn.benchmark = True
    args.start_epoch = 1

    # == positive/negative sample generators (dsm variants only) ==
    if args.pt_method in ['dsm', 'dsm_triplet']:
        pos_aug = GenPositive()
        neg_aug = GenNegative()

    recorder.record_message('a', '=' * 100)
    recorder.record_message('a', '-' * 40 + 'pretrain' + '-' * 40)
    recorder.record_message('a', '=' * 100)
    # Parse "a,b,c" into [a, b, c]. Accept an already-parsed sequence too, so
    # calling this twice with the same args object does not fail on .split().
    decay_epochs = args.pt_lr_decay_epochs
    if isinstance(decay_epochs, str):
        decay_epochs = decay_epochs.split(',')
    args.pt_lr_decay_epochs = [int(it) for it in decay_epochs]
    timer = Timer()
    print('*' * 70 + 'Step1: pretrain' + '*' * 20 + '*' * 50)
    for epoch in range(args.pt_start_epoch, args.pt_epochs + 1):
        timer.tic()
        pt_adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        if args.pt_method == "moco":
            loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                    contrast, criterion, optimizer, args,
                                    recorder)
        elif args.pt_method == "dsm":
            loss, prob = train_dsm(epoch, train_loader, model, model_ema,
                                   contrast, criterion, optimizer, args,
                                   pos_aug, neg_aug, recorder)
        elif args.pt_method == "dsm_triplet":
            loss = train_dsm_triplet(epoch, train_loader, model, optimizer,
                                     args, pos_aug, neg_aug, recorder)
        else:
            # The original code constructed this exception without raising
            # it. Execution then fell through with `loss` unbound.
            raise ValueError("Not support method now! ({})".format(
                args.pt_method))
        recorder.record_pt_train(loss)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        timer.toc()
        # Epochs still to run: range end is pt_epochs inclusive.
        left_time = timer.average_time * (args.pt_epochs - epoch)
        message = "Step1: pretrain now loss is: {} left time is : {} now is: {}".format(
            loss, timer.format(left_time), datetime.now())
        print(message)
        recorder.record_message('a', message)
        state = {
            'opt': args,
            'model': model.state_dict(),
            'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        recorder.save_pt_model(args, state, epoch)
    print("finished pretrain, the trained model is record in: {}".format(
        recorder.pt_checkpoint))
    return recorder.pt_checkpoint
# Example no. 3
# 0
def inference_mean_exemplar(
    model,
    current_epoch,
    current_iter,
    local_rank,
    data_loader,
    dataset_name,
    device="cuda",
    max_instance=3200,
    mute=False,
):
    """Evaluate object classification via ``model.mean_of_exemplar_classify``.

    The model is put in eval mode and run under ``torch.no_grad()`` over
    ``data_loader``, or a prefix of it. The micro-averaged F1 of the
    predicted object labels is then computed. The model is always switched
    back to training mode afterwards, even if evaluation raises.

    Args:
        model: Model exposing ``mean_of_exemplar_classify`` and, when
            ``max_instance`` is an int/float, ``cfg.EXTERNAL.BATCH_SIZE``.
        current_epoch, current_iter, local_rank, dataset_name: Not used
            here; kept for signature parity with ``inference``.
        data_loader: Iterable of batch dicts whose ``'object_labels'`` and
            ``'cropped_image'`` entries are lists of tensors.
        device: Device the batch tensors are moved to.
        max_instance: If an ``int``, stop after
            ``max_instance // BATCH_SIZE`` batches. If a ``float``, stop once
            the batch index exceeds
            ``max_instance * len(data_loader) // BATCH_SIZE``.
        mute: If True, suppress logging and the progress bar.

    Returns:
        ``(obj_f1, 0, num_object_labels)``. The middle element is a
        placeholder for the attribute F1, which this path does not compute.
    """

    def to_list(tensor):
        # Local helper: the only other `to_list` in this source is defined
        # inside `inference` and is not visible here.
        return tensor.cpu().numpy().tolist()

    model.train(False)
    # convert to a torch.device for efficiency
    device = torch.device(device)
    if not mute:
        logger = logging.getLogger("maskrcnn_benchmark.inference")
        logger.info("Start evaluation")
    total_timer = Timer()
    total_timer.tic()
    torch.cuda.empty_cache()
    pbar = None
    if not mute:
        pbar = tqdm(total=len(data_loader), desc="Validation in progress")
    all_pred_obj, all_truth_obj = [], []
    try:
        with torch.no_grad():
            for iteration, out_dict in enumerate(data_loader):
                if type(max_instance) is int:
                    if iteration == max_instance // model.cfg.EXTERNAL.BATCH_SIZE:
                        break
                if type(max_instance) is float:
                    # NOTE(review): len(data_loader) is usually already a batch
                    # count; the extra // BATCH_SIZE looks suspicious — confirm.
                    if iteration > max_instance * len(
                            data_loader) // model.cfg.EXTERNAL.BATCH_SIZE:
                        break
                obj_labels = torch.cat(out_dict['object_labels'], -1).to(device)
                cropped_image = torch.stack(out_dict['cropped_image']).to(device)
                pred_obj = model.mean_of_exemplar_classify(cropped_image)

                all_pred_obj.extend(to_list(pred_obj))
                all_truth_obj.extend(to_list(obj_labels))
                if pbar is not None:
                    pbar.update(1)

            obj_f1 = f1_score(all_truth_obj, all_pred_obj, average='micro')
    finally:
        if pbar is not None:
            pbar.close()
        model.train(True)
    total_timer.toc()
    return obj_f1, 0, len(all_truth_obj)
# Example no. 4
# 0
def inference(model,
              current_epoch,
              current_iter,
              local_rank,
              data_loader,
              dataset_name,
              device="cuda",
              max_instance=3200,
              mute=False,
              verbose_return=False):
    """Evaluate object and attribute prediction over ``data_loader``.

    Each batch is passed to ``inference_step``. Its result dict may carry
    ``'attr_loss'`` and/or ``'obj_loss'``. For each present loss, the
    predictions and labels are collected and the loss is accumulated. At the
    end, micro-averaged F1 scores and per-batch mean losses are computed. The
    model is run in eval mode under ``torch.no_grad()`` and is always switched
    back to training mode, even if evaluation raises.

    Args:
        model: Model evaluated by ``inference_step``. When ``max_instance``
            is an int/float it must expose ``cfg.EXTERNAL.BATCH_SIZE``.
        current_epoch, current_iter: Only used in the log message.
        local_rank, dataset_name: Not used here.
        data_loader: Iterable of batch dicts. With ``verbose_return`` they
            must also carry ``'image_ids'``, ``'gt_bboxes'`` and ``'raw'``.
        device: Device passed to ``inference_step``.
        max_instance: If an ``int``, stop after
            ``max_instance // BATCH_SIZE`` batches. If a ``float``, stop once
            the batch index exceeds
            ``max_instance * len(data_loader) // BATCH_SIZE``.
        mute: If True, suppress logging and the progress bar.
        verbose_return: If True, also return per-sample predictions, labels
            and metadata.

    Returns:
        ``(obj_f1, attr_f1, len(all_truth_attr))`` by default. With
        ``verbose_return``, returns ``(obj_f1, attr_f1, all_pred_attr,
        all_truth_attr, all_pred_obj, all_truth_obj, all_image_ids,
        all_boxes, all_pred_attr_prob, all_raws)``.
    """
    model.train(False)
    # convert to a torch.device for efficiency
    device = torch.device(device)
    if not mute:
        logger = logging.getLogger("maskrcnn_benchmark.inference")
        logger.info("Start evaluation")
    total_timer = Timer()
    total_timer.tic()
    torch.cuda.empty_cache()
    pbar = None
    if not mute:
        pbar = tqdm(total=len(data_loader), desc="Validation in progress")

    def to_list(tensor):
        return tensor.cpu().numpy().tolist()

    all_pred_obj, all_truth_obj, all_pred_attr, all_truth_attr = [], [], [], []
    all_image_ids, all_boxes = [], []
    all_pred_attr_prob = []
    all_raws = []
    try:
        with torch.no_grad():
            obj_loss_all, attr_loss_all = 0, 0
            cnt = 0
            for iteration, out_dict in enumerate(data_loader):
                if type(max_instance) is int:
                    if iteration == max_instance // model.cfg.EXTERNAL.BATCH_SIZE:
                        break
                if type(max_instance) is float:
                    # NOTE(review): len(data_loader) is usually already a batch
                    # count; the extra // BATCH_SIZE looks suspicious — confirm.
                    if iteration > max_instance * len(
                            data_loader) // model.cfg.EXTERNAL.BATCH_SIZE:
                        break

                if verbose_return:
                    all_image_ids.extend(out_dict['image_ids'])
                    all_boxes.extend(out_dict['gt_bboxes'])
                    all_raws.extend(out_dict['raw'])

                ret_dict = inference_step(model, out_dict, device)
                loss_attr = ret_dict.get('attr_loss', None)
                loss_obj = ret_dict.get('obj_loss', None)
                obj_score = ret_dict.get('obj_score', None)

                if loss_attr is not None:
                    attr_loss_all += loss_attr.item()
                    pred_attr_prob = ret_dict['pred_attr_prob']
                    pred_attr = ret_dict['pred_attr']
                    all_pred_attr.extend(to_list(pred_attr))
                    all_truth_attr.extend(to_list(ret_dict['attr_labels']))
                    all_pred_attr_prob.extend(to_list(pred_attr_prob))
                if loss_obj is not None:
                    obj_loss_all += loss_obj.item()
                    # Predicted class = argmax over the last dimension.
                    _, pred_obj = obj_score.max(-1)
                    all_pred_obj.extend(to_list(pred_obj))
                    all_truth_obj.extend(to_list(ret_dict['obj_labels']))
                cnt += 1
                if pbar is not None:
                    pbar.update(1)

            obj_f1 = f1_score(all_truth_obj, all_pred_obj, average='micro')
            attr_f1 = f1_score(all_truth_attr, all_pred_attr, average='micro')
            # Epsilon avoids division by zero when no batch was processed.
            obj_loss_all /= (cnt + 1e-10)
            attr_loss_all /= (cnt + 1e-10)
            if not mute:
                logger.info(
                    'Epoch: {}\tIteration: {}\tObject f1: {}\tAttr f1:{}\tObject loss:{}\tAttr loss:{}'
                    .format(current_epoch, current_iter, obj_f1, attr_f1,
                            obj_loss_all, attr_loss_all))
    finally:
        if pbar is not None:
            pbar.close()
        model.train(True)
    total_timer.toc()
    if not verbose_return:
        return obj_f1, attr_f1, len(all_truth_attr)
    return obj_f1, attr_f1, all_pred_attr, all_truth_attr, all_pred_obj, all_truth_obj, all_image_ids, all_boxes, \
        all_pred_attr_prob, all_raws
# Example no. 5
# 0
def train_and_eval(args):
    """Train ``args.mutual_num`` models jointly, validating periodically.

    Builds loaders, ``args.mutual_num`` models and one optimizer per model,
    then trains for epochs in ``range(args.start_epoch, args.epochs)``. Every
    ``args.eval_freq`` epochs the models are validated and a checkpoint is
    saved through ``recorder.save_model``. With ``args.eval_indict == 'acc'``
    the best checkpoint tracks the highest validation precision (module
    global ``best_prec1``). Otherwise it tracks the lowest validation loss
    (module global ``lowest_val_loss``).

    Args:
        args: Run configuration (``gpus``, ``mutual_num``, ``lr``,
            ``lr_steps``, ``epochs``, ``eval_freq``, ``eval_indict``, ...).

    Returns:
        The ``Record`` instance's ``filename``.
    """
    global lowest_val_loss, best_prec1
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # close the warning
    # Assigning None to os.environ raises TypeError; only set when provided.
    if args.gpus is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    torch.manual_seed(1)
    cudnn.benchmark = True
    timer = Timer()
    recorder = Record(args)
    use_acc = args.eval_indict == 'acc'
    # == dataset config==
    num_class, data_length, image_tmpl = data_config(args)
    train_transforms, test_transforms = augmentation_config(args)
    train_data_loader, val_data_loader = data_loader_init(
        args, data_length, image_tmpl, train_transforms, test_transforms)
    # == model config==
    models = [model_config(args, num_class) for _ in range(args.mutual_num)]
    # The last model is what gets logged and checkpointed below.
    model = models[-1]
    recorder.record_message('a', '=' * 100)
    recorder.record_message('a', str(model.module))
    recorder.record_message('a', '=' * 100)
    # == optim config==
    # One optimizer per model, each built for *its own* model. Previously
    # every optimizer was built for the last model only.
    optimizers = []
    for member in models:
        train_criterion, val_criterion, optimizer = optim_init(args, member)
        optimizers.append(optimizer)
    # == data augmentation(self-supervised) config==
    tc = TC(args)
    # == train and eval==
    for epoch in range(args.start_epoch, args.epochs):
        timer.tic()
        for optimizer in optimizers:
            adjust_learning_rate(optimizer, args.lr, epoch, args.lr_steps)
        if use_acc:
            train_prec1, train_loss = train(args, tc, train_data_loader,
                                            models, train_criterion,
                                            optimizers, epoch, recorder)
            recorder.record_train(train_loss / 5.0, train_prec1 / 100.0)
        else:
            train_loss = train(args, tc, train_data_loader, models,
                               train_criterion, optimizers, epoch, recorder)
            recorder.record_train(train_loss)
        if (epoch + 1) % args.eval_freq == 0:
            if use_acc:
                val_prec1, val_loss = validate(args, tc, val_data_loader,
                                               models, val_criterion, recorder)
                recorder.record_val(val_loss / 5.0, val_prec1 / 100.0)
                is_best = val_prec1 > best_prec1
                best_prec1 = max(val_prec1, best_prec1)
                checkpoint = {
                    'epoch': epoch + 1,
                    'arch': "i3d",
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1
                }
            else:
                val_loss = validate(args, tc, val_data_loader, models,
                                    val_criterion, recorder)
                recorder.record_val(val_loss)
                is_best = val_loss < lowest_val_loss
                lowest_val_loss = min(val_loss, lowest_val_loss)
                checkpoint = {
                    'epoch': epoch + 1,
                    'arch': "i3d",
                    'state_dict': model.state_dict(),
                    'lowest_val': lowest_val_loss
                }
            # Save only on evaluation epochs. `checkpoint` and `is_best` are
            # bound only here, so saving outside this branch raised
            # UnboundLocalError whenever eval_freq > 1. It also re-used a
            # stale `is_best` on later non-eval epochs.
            recorder.save_model(checkpoint, is_best)
        timer.toc()
        # Number of epochs still to run after the current one.
        left_time = timer.average_time * (args.epochs - epoch - 1)

        if use_acc:
            message = "best_prec1 is: {} left time is : {}".format(
                best_prec1, timer.format(left_time))
        else:
            message = "lowest_val_loss is: {} left time is : {}".format(
                lowest_val_loss, timer.format(left_time))
        print(message)
        recorder.record_message('a', message)
    return recorder.filename