def fine_tune_train_and_val(args, recorder):
    """Run the fine-tuning stage (Step2): train, periodically validate, checkpoint.

    Builds the dataset loaders, model, criteria and optimizer through the
    ``ft_*`` helpers, then loops over epochs
    ``range(args.ft_start_epoch, args.ft_epochs)``.  Validation runs every
    ``args.ft_eval_freq`` epochs; a checkpoint is saved right after each
    validation pass.

    Args:
        args: Run configuration. Fields read here: ``ft_start_epoch``,
            ``ft_epochs``, ``ft_lr``, ``ft_lr_steps``, ``ft_eval_freq``
            (plus whatever the ``ft_*`` helpers read).
        recorder: Logging/checkpoint object providing ``record_message``,
            ``record_ft_train``, ``record_ft_val``, ``save_ft_model`` and
            ``filename``.

    Returns:
        ``recorder.filename``.

    Side effects:
        Updates the module-level ``best_prec1``, sets the
        ``TF_CPP_MIN_LOG_LEVEL`` environment variable, seeds torch with 1 and
        enables ``cudnn.benchmark``.
    """
    global lowest_val_loss, best_prec1
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # close the warning
    torch.manual_seed(1)
    cudnn.benchmark = True
    timer = Timer()
    # == dataset config==
    num_class, data_length, image_tmpl = ft_data_config(args)
    train_transforms, test_transforms, eval_transforms = ft_augmentation_config(
        args)
    train_data_loader, val_data_loader, _, _, _, _ = ft_data_loader_init(
        args, data_length, image_tmpl, train_transforms, test_transforms,
        eval_transforms)
    # == model config==
    model = ft_model_config(args, num_class)
    recorder.record_message('a', '=' * 100)
    recorder.record_message('a', '-' * 40 + 'finetune' + '-' * 40)
    recorder.record_message('a', '=' * 100)
    # == optim config==
    train_criterion, val_criterion, optimizer = ft_optim_init(args, model)
    # == data augmentation(self-supervised) config==
    tc = TC(args)
    # == train and eval==
    print('*' * 70 + 'Step2: fine tune' + '*' * 50)
    for epoch in range(args.ft_start_epoch, args.ft_epochs):
        timer.tic()
        ft_adjust_learning_rate(optimizer, args.ft_lr, epoch, args.ft_lr_steps)
        train_prec1, train_loss = train(args, tc, train_data_loader, model,
                                        train_criterion, optimizer, epoch,
                                        recorder)
        # Loss and precision are rescaled before being recorded
        # (presumably to bring both into a similar range for plotting -- confirm).
        recorder.record_ft_train(train_loss / 5.0, train_prec1 / 100.0)
        if (epoch + 1) % args.ft_eval_freq == 0:
            val_prec1, val_loss = validate(args, tc, val_data_loader, model,
                                           val_criterion, recorder)
            recorder.record_ft_val(val_loss / 5.0, val_prec1 / 100.0)
            is_best = val_prec1 > best_prec1
            best_prec1 = max(val_prec1, best_prec1)
            checkpoint = {
                'epoch': epoch + 1,
                'arch': "i3d",
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1
            }
            # Save only when a fresh checkpoint exists: `checkpoint` and
            # `is_best` are undefined (or stale) on epochs without validation.
            recorder.save_ft_model(checkpoint, is_best)
        timer.toc()
        # Epochs are 0-based with an exclusive upper bound, so after finishing
        # `epoch` there are `ft_epochs - epoch - 1` epochs still to run.
        left_time = timer.average_time * (args.ft_epochs - epoch - 1)
        message = "Step2: fine tune best_prec1 is: {} left time is : {} now is : {}".format(
            best_prec1, timer.format(left_time), datetime.now())
        print(message)
        recorder.record_message('a', message)
    return recorder.filename
Exemple #2
0
def pretext_train(args, recorder):
    """Run the self-supervised pretext stage (Step1) and checkpoint every epoch.

    Sets up data, model(s), contrast module, criterion and optimizer through
    the ``pt_*`` helpers, then trains for epochs
    ``range(args.pt_start_epoch, args.pt_epochs + 1)`` with the routine
    selected by ``args.pt_method`` (``"moco"``, ``"dsm"`` or
    ``"dsm_triplet"``).

    Args:
        args: Run configuration. Fields read here: ``gpus``, ``pt_method``,
            ``pt_lr_decay_epochs``, ``pt_start_epoch``, ``pt_epochs``.
            ``pt_lr_decay_epochs`` may be a comma-separated string or an
            iterable of ints; it is replaced in place by a list of ints.
            ``args.start_epoch`` is overwritten with 1.
        recorder: Logging/checkpoint object providing ``record_message``,
            ``record_pt_train``, ``save_pt_model`` and ``pt_checkpoint``.

    Returns:
        ``recorder.pt_checkpoint``.

    Raises:
        ValueError: If ``args.pt_method`` is not one of the supported methods
            (raised when the first training epoch starts).
    """
    if args.gpus is not None:
        print("Use GPU: {} for pretext training".format(args.gpus))
    num_class, data_length, image_tmpl = pt_data_config(args)
    train_transforms, test_transforms, eval_transforms = pt_augmentation_config(
        args)
    # Only the training loader is used in this stage.
    train_loader, _, _, _, _, _ = pt_data_loader_init(
        args, data_length, image_tmpl, train_transforms, test_transforms,
        eval_transforms)

    n_data = len(train_loader)

    model, model_ema = pt_model_config(args, num_class)
    # == optim config==
    contrast, criterion, optimizer = pt_optim_init(args, model, n_data)
    model = model.cuda()
    # == load weights ==
    model, model_ema = pt_load_weight(args, model, model_ema, optimizer,
                                      contrast)
    if args.pt_method in ['dsm', 'moco']:
        model_ema = model_ema.cuda()
        # copy weights from `model' to `model_ema'
        moment_update(model, model_ema, 0)
    cudnn.benchmark = True
    # NOTE(review): the loop below starts at args.pt_start_epoch, not at this
    # attribute; kept because other code may read args.start_epoch.
    args.start_epoch = 1

    # == our data augmentation method (only the dsm variants use it) ==
    pos_aug = neg_aug = None
    if args.pt_method in ['dsm', 'dsm_triplet']:
        pos_aug = GenPositive()
        neg_aug = GenNegative()

    # == add message ==
    recorder.record_message('a', '=' * 100)
    recorder.record_message('a', '-' * 40 + 'pretrain' + '-' * 40)
    recorder.record_message('a', '=' * 100)
    # == normalise lr decay epochs to a list of ints ==
    # Accept an already-parsed list too, so calling this function twice with
    # the same `args` does not fail on `.split`.
    decay_epochs = args.pt_lr_decay_epochs
    if isinstance(decay_epochs, str):
        decay_epochs = decay_epochs.split(',')
    args.pt_lr_decay_epochs = [int(it) for it in decay_epochs]
    timer = Timer()
    # routine
    print('*' * 70 + 'Step1: pretrain' + '*' * 20 + '*' * 50)
    for epoch in range(args.pt_start_epoch, args.pt_epochs + 1):
        timer.tic()
        pt_adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        if args.pt_method == "moco":
            loss, _ = train_moco(epoch, train_loader, model, model_ema,
                                 contrast, criterion, optimizer, args,
                                 recorder)
        elif args.pt_method == "dsm":
            loss, _ = train_dsm(epoch, train_loader, model, model_ema,
                                contrast, criterion, optimizer, args,
                                pos_aug, neg_aug, recorder)
        elif args.pt_method == "dsm_triplet":
            loss = train_dsm_triplet(epoch, train_loader, model, optimizer,
                                     args, pos_aug, neg_aug, recorder)
        else:
            # Previously the exception object was built but never raised,
            # which surfaced later as an unbound `loss`.
            raise ValueError("Not support method now!")
        recorder.record_pt_train(loss)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        timer.toc()
        # Epochs here are 1-based and inclusive, so `pt_epochs - epoch`
        # is the number of epochs still to run.
        left_time = timer.average_time * (args.pt_epochs - epoch)
        message = "Step1: pretrain now loss is: {} left time is : {} now is: {}".format(
            loss, timer.format(left_time), datetime.now())
        print(message)
        recorder.record_message('a', message)
        state = {
            'opt': args,
            'model': model.state_dict(),
            'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        recorder.save_pt_model(args, state, epoch)
    print("finished pretrain, the trained model is record in: {}".format(
        recorder.pt_checkpoint))
    return recorder.pt_checkpoint
Exemple #3
0
def train_and_eval(args):
    """Train ``args.mutual_num`` models together, validating and checkpointing.

    Creates a ``Record`` for the run, builds data loaders, one model and one
    optimizer per mutual-learning peer, then loops over epochs
    ``range(args.start_epoch, args.epochs)``.  Validation runs every
    ``args.eval_freq`` epochs and a checkpoint is saved right after it.
    ``args.eval_indict == 'acc'`` selects accuracy-based tracking
    (``best_prec1``); any other value selects loss-based tracking
    (``lowest_val_loss``).

    Args:
        args: Run configuration. Fields read here: ``gpus``, ``mutual_num``,
            ``start_epoch``, ``epochs``, ``lr``, ``lr_steps``, ``eval_freq``,
            ``eval_indict``.

    Returns:
        ``recorder.filename`` of the ``Record`` created for this run.

    Side effects:
        Updates the module-level ``best_prec1`` / ``lowest_val_loss``, sets
        ``TF_CPP_MIN_LOG_LEVEL`` and ``CUDA_VISIBLE_DEVICES``, seeds torch
        with 1 and enables ``cudnn.benchmark``.
    """
    global lowest_val_loss, best_prec1
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # close the warning
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    torch.manual_seed(1)
    cudnn.benchmark = True
    timer = Timer()
    recorder = Record(args)
    use_acc = args.eval_indict == 'acc'
    # == dataset config==
    num_class, data_length, image_tmpl = data_config(args)
    train_transforms, test_transforms = augmentation_config(args)
    train_data_loader, val_data_loader = data_loader_init(
        args, data_length, image_tmpl, train_transforms, test_transforms)
    # == model config==
    models = []
    for _ in range(args.mutual_num):
        model = model_config(args, num_class)
        models.append(model)
    # `model` is the last peer created; it is the one logged and checkpointed.
    recorder.record_message('a', '=' * 100)
    recorder.record_message('a', str(model.module))
    recorder.record_message('a', '=' * 100)
    # == optim config==
    # One optimizer per peer, bound to that peer's own parameters.  The
    # original passed the last model every time, so all optimizers updated
    # the same network.
    optimizers = []
    for peer in models:
        train_criterion, val_criterion, optimizer = optim_init(args, peer)
        optimizers.append(optimizer)
    # == data augmentation(self-supervised) config==
    tc = TC(args)
    # == train and eval==
    for epoch in range(args.start_epoch, args.epochs):
        timer.tic()
        for optimizer in optimizers:
            adjust_learning_rate(optimizer, args.lr, epoch, args.lr_steps)
        if use_acc:
            train_prec1, train_loss = train(args, tc, train_data_loader,
                                            models, train_criterion,
                                            optimizers, epoch, recorder)
            # Rescaled before recording (presumably for plotting -- confirm).
            recorder.record_train(train_loss / 5.0, train_prec1 / 100.0)
        else:
            train_loss = train(args, tc, train_data_loader, models,
                               train_criterion, optimizers, epoch, recorder)
            recorder.record_train(train_loss)
        if (epoch + 1) % args.eval_freq == 0:
            # NOTE(review): only the last peer's weights go into the
            # checkpoint, as before -- confirm this is intended.
            if use_acc:
                val_prec1, val_loss = validate(args, tc, val_data_loader,
                                               models, val_criterion, recorder)
                recorder.record_val(val_loss / 5.0, val_prec1 / 100.0)
                is_best = val_prec1 > best_prec1
                best_prec1 = max(val_prec1, best_prec1)
                checkpoint = {
                    'epoch': epoch + 1,
                    'arch': "i3d",
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1
                }
            else:
                val_loss = validate(args, tc, val_data_loader, models,
                                    val_criterion, recorder)
                recorder.record_val(val_loss)
                is_best = val_loss < lowest_val_loss
                lowest_val_loss = min(val_loss, lowest_val_loss)
                checkpoint = {
                    'epoch': epoch + 1,
                    'arch': "i3d",
                    'state_dict': model.state_dict(),
                    'lowest_val': lowest_val_loss
                }
            # Save only when a fresh checkpoint exists: `checkpoint` and
            # `is_best` are undefined (or stale) on epochs without validation.
            recorder.save_model(checkpoint, is_best)
        timer.toc()
        # Epochs are 0-based with an exclusive upper bound, so after finishing
        # `epoch` there are `epochs - epoch - 1` epochs still to run.
        left_time = timer.average_time * (args.epochs - epoch - 1)

        if use_acc:
            message = "best_prec1 is: {} left time is : {}".format(
                best_prec1, timer.format(left_time))
        else:
            message = "lowest_val_loss is: {} left time is : {}".format(
                lowest_val_loss, timer.format(left_time))
        print(message)
        recorder.record_message('a', message)
    return recorder.filename