Example #1
def test(opt, testset, model):
  print('Begin testing')
  # test split, used to produce submission score files
  if opt.return_test_rank:
    for dataname in testset.data_name:
      sims = test_retrieval.test(opt, model, testset, dataname)
      np.save(opt.log_dir + '/test.{}.{}.scores.npy'.format(dataname, opt.model), sims)
    exit()

  tests = []
  rsum = 0
  for dataname in testset.data_name:
    t, sims = test_retrieval.test(opt, model, testset, dataname)
    np.save(opt.log_dir + '/val.{}.{}.scores.npy'.format(dataname, opt.model), sims)
    for metric_name, metric_value in t:
      tests += [(metric_name, metric_value)]
      rsum += metric_value
  tests += [('rmean', rsum / 6)]  # 6 = number of metric values summed above (hardcoded)

  for metric_name, metric_value in tests:
    print('    ', metric_name, round(metric_value, 2))
  print('Finished testing')
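
Every snippet on this page calls into a test_retrieval module whose source is not shown. As a reading aid only, here is a minimal sketch of the interface the snippets assume: a list of (metric_name, metric_value) pairs plus a query-by-candidate similarity matrix. The metric names, sizes, and toy scoring below are hypothetical, not the original implementation.

import numpy as np

def test(opt, model, testset, dataname=None):
    # Toy stand-in: score 8 queries against 8 candidates and report
    # recall@k, pretending candidate i is the ground truth for query i.
    sims = np.random.rand(8, 8)
    ranks = (-sims).argsort(axis=1)
    metrics = []
    for k in (1, 10, 50):
        hits = [i in ranks[i, :k] for i in range(len(sims))]
        metrics.append(('recall_top{}'.format(k), 100.0 * float(np.mean(hits))))
    return metrics, sims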
Example #2
def main():
    opt = parse_opt()
    print('Arguments:')
    for k in opt.__dict__.keys():
        print('    ', k, ':', str(opt.__dict__[k]))

    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    loss_weights = [1.0, 0.1, 0.1, 0.01]
    logdir = os.path.join(
        opt.log_dir, current_time + '_' + socket.gethostname() + opt.comment)

    logger = SummaryWriter(logdir)
    print('Log files saved to', logger.file_writer.get_logdir())
    for k in opt.__dict__.keys():
        logger.add_text(k, str(opt.__dict__[k]))

    trainset, testset = load_dataset(opt)
    model, optimizer = create_model_and_optimizer(
        opt, [t for t in trainset.get_all_texts()])
    if opt.test_only:
        print('Doing test only')
        checkpoint = torch.load(opt.model_checkpoint)
        model.load_state_dict(checkpoint['model_state_dict'])
        it = checkpoint['it']  # log test metrics at the checkpoint's iteration
        model.eval()
        tests = []
        for name, dataset in [('train', trainset), ('test', testset)]:
            if opt.dataset == 'fashionIQ':
                t = test_retrieval.fiq_test(opt, model, dataset)
            else:
                t = test_retrieval.test(opt, model, dataset)
            tests += [(name + ' ' + metric_name, metric_value)
                      for metric_name, metric_value in t]
        for metric_name, metric_value in tests:
            logger.add_scalar(metric_name, metric_value, it)
            print('    ', metric_name, round(metric_value, 4))

        return 0
    train_loop(opt, loss_weights, logger, trainset, testset, model, optimizer)
    logger.close()
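
main() depends on a parse_opt() that is not reproduced on this page. The sketch below is a hedged reconstruction: the flag names are inferred from how opt is read across these examples, and every default is a placeholder, not the original project's.

import argparse

def parse_opt():
    # Flags inferred from attribute accesses on `opt` in the snippets on
    # this page; defaults are hypothetical.
    p = argparse.ArgumentParser()
    p.add_argument('--dataset', type=str, default='fashion200k')
    p.add_argument('--model', type=str, default='tirg')
    p.add_argument('--loss', type=str, default='soft_triplet')
    p.add_argument('--log_dir', type=str, default='./logs')
    p.add_argument('--comment', type=str, default='')
    p.add_argument('--test_only', action='store_true')
    p.add_argument('--model_checkpoint', type=str, default='')
    p.add_argument('--batch_size', type=int, default=32)
    p.add_argument('--num_iters', type=int, default=160000)
    p.add_argument('--loader_num_workers', type=int, default=4)
    p.add_argument('--learning_rate_decay_frequency', type=int, default=50000)
    return p.parse_args()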
Example #3
def run_eval(opt, logger, dataset_dict, model, it, eval_on_test=False):
    trainset = dataset_dict["train"]
    if eval_on_test:
        testset = dataset_dict["test"]
    else:
        testset = dataset_dict.get("val", dataset_dict["test"])
    tests = []
    for name, dataset in [('train', trainset), ('test', testset)]:
        categ = opt.dataset == "fashioniq" and name == 'test'
        t = test_retrieval.test(opt, model, dataset, filter_categories=categ)
        tests += [(name + ' ' + metric_name, metric_value)
                  for metric_name, metric_value in t]
    for metric_name, metric_value in tests:
        logger.add_scalar(metric_name, metric_value, it)
        print('    ', metric_name, round(metric_value, 4))
    if opt.dataset == "fashioniq":
        scores = [
            metric for name, metric in tests
            if "test" in name and ("top10_" in name or "top50_" in name)
        ]
        fiq_score = np.mean(scores)
        logger.add_scalar("fiq_score", fiq_score, it)
        print('    ', 'fiq_score', round(fiq_score, 4))
def train_loop(opt, trainset, testset, model, optimizer):
    """Function for train loop"""
    print('Begin training')
    losses_tracking = {}
    it = 0
    epoch = -1
    tic = time.time()
    while it < opt.num_iters:
        epoch += 1

        # show/log stats
        print('It', it, 'epoch', epoch, 'Elapsed time',
              round(time.time() - tic, 4), opt.comment)
        tic = time.time()
        for loss_name in losses_tracking:
            # losses_tracking is empty on epoch 0, so trainloader (defined
            # later in this loop body) is never referenced before assignment
            avg_loss = np.mean(losses_tracking[loss_name][-len(trainloader):])
            print('    Loss', loss_name, round(avg_loss, 4))
            #logger.add_scalar(loss_name, avg_loss, it)

        #logger.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], it)

        # test
        if epoch % 3 == 1:
            tests = []
            for name, dataset in [('train', trainset), ('test', testset)]:
                t = test_retrieval.test(opt, model, dataset)
                tests += [(name + ' ' + metric_name, metric_value)
                          for metric_name, metric_value in t]

            for metric_name, metric_value in tests:
                #  logger.add_scalar(metric_name, metric_value, it)
                print('    ', metric_name, round(metric_value, 4))

        # save checkpoint
        torch.save(
            {
                'it': it,
                'opt': opt,
                'model_state_dict': model.state_dict(),
            },
            '/content/drive/My Drive/colab_model/tirg/latest_checkpoint.pth')

        # run training for 1 epoch
        model.train()
        trainloader = trainset.get_loader(batch_size=opt.batch_size,
                                          shuffle=True,
                                          drop_last=True,
                                          num_workers=opt.loader_num_workers)

        def training_1_iter(data):
            assert type(data) is list
            img1 = np.stack([d['source_img_data'] for d in data])
            img1 = torch.from_numpy(img1).float()
            img1 = torch.autograd.Variable(img1).cuda()
            img2 = np.stack([d['target_img_data'] for d in data])
            img2 = torch.from_numpy(img2).float()
            img2 = torch.autograd.Variable(img2).cuda()
            # bytes→str decode only where the dataset yields bytes (str has
            # no .decode in Python 3)
            mods = [d['mod']['str'] for d in data]
            mods = [m.decode('utf-8') if isinstance(m, bytes) else str(m) for m in mods]

            # compute loss
            losses = []
            if opt.loss == 'soft_triplet':
                loss_value = model.compute_loss(img1,
                                                mods,
                                                img2,
                                                soft_triplet_loss=True)
            elif opt.loss == 'batch_based_classification':
                loss_value = model.compute_loss(img1,
                                                mods,
                                                img2,
                                                soft_triplet_loss=False)
            else:
                print('Invalid loss function', opt.loss)
                sys.exit()

            loss_name = opt.loss
            loss_weight = 1.0
            losses += [(loss_name, loss_weight, loss_value)]

            total_loss = sum([
                loss_weight * loss_value
                for loss_name, loss_weight, loss_value in losses
            ])
            assert not torch.isnan(total_loss)
            losses += [('total training loss', None, total_loss)]

            # track losses
            for loss_name, loss_weight, loss_value in losses:
                if loss_name not in losses_tracking:
                    losses_tracking[loss_name] = []
                losses_tracking[loss_name].append(float(loss_value))

            # gradient descent
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        for data in tqdm(trainloader, desc='Training for epoch ' + str(epoch)):
            it += 1
            training_1_iter(data)

            # decay learning rate
            if it >= opt.learning_rate_decay_frequency and it % opt.learning_rate_decay_frequency == 0:
                for g in optimizer.param_groups:
                    g['lr'] *= 0.1

    print('Finished training')
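
The decay block at the bottom of this loop (repeated in the later examples) multiplies every parameter group's learning rate by 0.1 each time `it` reaches a multiple of learning_rate_decay_frequency. Stepped once per training iteration, torch.optim.lr_scheduler.StepLR expresses the same schedule; a minimal sketch with a toy model:

import torch

model = torch.nn.Linear(4, 4)  # toy stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# gamma=0.1 every `step_size` scheduler steps mirrors the manual decay above
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50000, gamma=0.1)

for it in range(100):  # the real loop iterates over trainloader
    optimizer.step()   # backward() omitted in this toy sketch
    scheduler.step()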
Example #5
def train_loop(opt, loss_weights, logger, trainset, testset, model, optimizer):
    """Function for train loop"""
    print('Begin training')
    print(len(trainset.test_queries), len(testset.test_queries))
    torch.backends.cudnn.benchmark = True
    losses_tracking = {}
    it = 0
    epoch = -1
    tic = time.time()
    l2_loss = torch.nn.MSELoss().cuda()

    while it < opt.num_iters:
        epoch += 1

        # show/log stats
        print('It', it, 'epoch', epoch, 'Elapsed time',
              round(time.time() - tic, 4), opt.comment)
        tic = time.time()
        for loss_name in losses_tracking:
            avg_loss = np.mean(losses_tracking[loss_name][-len(trainloader):])
            print('    Loss', loss_name, round(avg_loss, 4))
            logger.add_scalar(loss_name, avg_loss, it)
        logger.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], it)

        if epoch % 1 == 0:  # always true: collect garbage every epoch
            gc.collect()

        # test
        if epoch % 3 == 1:
            tests = []
            for name, dataset in [('train', trainset), ('test', testset)]:
                if opt.dataset == 'fashionIQ':
                    t = test_retrieval.fiq_test(opt, model, dataset)
                else:
                    t = test_retrieval.test(opt, model, dataset)
                tests += [(name + ' ' + metric_name, metric_value)
                          for metric_name, metric_value in t]
            for metric_name, metric_value in tests:
                logger.add_scalar(metric_name, metric_value, it)
                print('    ', metric_name, round(metric_value, 4))

        # save checkpoint
        torch.save(
            {
                'it': it,
                'opt': opt,
                'model_state_dict': model.state_dict(),
            },
            logger.file_writer.get_logdir() + '/latest_checkpoint.pth')

        # run training for 1 epoch
        model.train()
        trainloader = trainset.get_loader(batch_size=opt.batch_size,
                                          shuffle=True,
                                          drop_last=True,
                                          num_workers=opt.loader_num_workers)

        def training_1_iter(data):
            assert type(data) is list
            img1 = np.stack([d['source_img_data'] for d in data])
            img1 = torch.from_numpy(img1).float()
            img1 = torch.autograd.Variable(img1).cuda()

            img2 = np.stack([d['target_img_data'] for d in data])
            img2 = torch.from_numpy(img2).float()
            img2 = torch.autograd.Variable(img2).cuda()

            if opt.use_complete_text_query:
                if opt.dataset == 'mitstates':
                    supp_text = [str(d['noun']) for d in data]
                    mods = [str(d['mod']['str']) for d in data]
                    # text_query here means complete_text_query
                    text_query = [
                        adj + " " + noun for adj, noun in zip(mods, supp_text)
                    ]
                else:
                    text_query = [str(d['target_caption']) for d in data]
            else:
                text_query = [str(d['mod']['str']) for d in data]
            # compute loss
            if opt.loss not in ['soft_triplet', 'batch_based_classification']:
                print('Invalid loss function', opt.loss)
                sys.exit()

            losses = []
            if_soft_triplet = opt.loss == 'soft_triplet'
            loss_value, dct_with_representations = model.compute_loss(
                img1, text_query, img2, soft_triplet_loss=if_soft_triplet)

            loss_name = opt.loss
            losses += [(loss_name, loss_weights[0], loss_value.cuda())]

            if opt.model == 'composeAE':
                dec_img_loss = l2_loss(
                    dct_with_representations["repr_to_compare_with_source"],
                    dct_with_representations["img_features"])
                dec_text_loss = l2_loss(
                    dct_with_representations["repr_to_compare_with_mods"],
                    dct_with_representations["text_features"])

                losses += [("L2_loss", loss_weights[1], dec_img_loss.cuda())]
                losses += [("L2_loss_text", loss_weights[2],
                            dec_text_loss.cuda())]
                losses += [("rot_sym_loss", loss_weights[3],
                            dct_with_representations["rot_sym_loss"].cuda())]
            elif opt.model == 'RealSpaceConcatAE':
                dec_img_loss = l2_loss(
                    dct_with_representations["repr_to_compare_with_source"],
                    dct_with_representations["img_features"])
                dec_text_loss = l2_loss(
                    dct_with_representations["repr_to_compare_with_mods"],
                    dct_with_representations["text_features"])

                losses += [("L2_loss", loss_weights[1], dec_img_loss.cuda())]
                losses += [("L2_loss_text", loss_weights[2],
                            dec_text_loss.cuda())]

            total_loss = sum([
                loss_weight * loss_value
                for loss_name, loss_weight, loss_value in losses
            ])
            assert not torch.isnan(total_loss)
            losses += [('total training loss', None, total_loss.item())]

            # track losses
            for loss_name, loss_weight, loss_value in losses:
                if loss_name not in losses_tracking:
                    losses_tracking[loss_name] = []
                losses_tracking[loss_name].append(float(loss_value))

            # anomaly detection is costly; typically enabled once, for debugging
            torch.autograd.set_detect_anomaly(True)

            # gradient descent
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        for data in tqdm(trainloader, desc='Training for epoch ' + str(epoch)):
            it += 1
            training_1_iter(data)

            # decay learning rate
            if it >= opt.learning_rate_decay_frequency and it % opt.learning_rate_decay_frequency == 0:
                for g in optimizer.param_groups:
                    g['lr'] *= 0.1

    print('Finished training')
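
composeAE combines four weighted terms here: the main retrieval loss plus two L2 reconstruction losses and a rotational-symmetry loss, with the 1.0/0.1/0.1/0.01 weights set in Example #2. A sketch of the weighted-sum pattern in isolation, with toy values:

import torch

def weighted_total(losses):
    # losses: (name, weight, tensor) triples as built in the loop above;
    # the weighted sum is what backward() is called on
    total = sum(w * v for _, w, v in losses)
    assert not torch.isnan(total)
    return total

parts = [('main_loss', 1.0, torch.tensor(0.9)),
         ('L2_loss', 0.1, torch.tensor(0.5)),
         ('L2_loss_text', 0.1, torch.tensor(0.4)),
         ('rot_sym_loss', 0.01, torch.tensor(0.2))]
print(weighted_total(parts))  # tensor(0.9920)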
Example #6
def GetValuestrain15time():

    with open(Path1 + "/trainBetaNormalized.txt", 'rb') as fp:
        BetaNormalize = pickle.load(fp)

    trainset = datasets.Fashion200k(
        path=Path1,
        split='train',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize(224),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                             [0.229, 0.224, 0.225])
        ]))

    trainloader = trainset.get_loader(batch_size=2,
                                      shuffle=True,
                                      drop_last=True,
                                      num_workers=0)

    testset = TestFashion200k(
        path=Path1,
        split='test',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize(224),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                             [0.229, 0.224, 0.225])
        ]))

    trig = TIRG([t.encode().decode('utf-8') for t in trainset.get_all_texts()],
                512)
    trig.load_state_dict(
        torch.load(Path1 + r'\checkpoint_fashion200k.pth',
                   map_location=torch.device('cpu'))['model_state_dict'])

    # an ArgumentParser instance is used here as a plain options object;
    # only the attributes set below are read by test_retrieval
    opt = argparse.ArgumentParser()
    opt.add_argument('--batch_size', type=int, default=2)
    opt.add_argument('--dataset', type=str, default='fashion200k')
    opt.batch_size = 1
    opt.dataset = 'fashion200k'

    Results = []

    for i in range(15):
        for name, dataset in [('train', trainset)]:  #,('test', testset)]:

            # betaNor="['1 ---> 5.27', '5 ---> 14.39', '10 ---> 21.6', '50 ---> 43.830000000000005', '100 ---> 55.33']"
            # Results.append('No.'+str(i)+' DataSet='+name+' Type= BetaNormalized '+' Result=' +betaNor)
            try:
                betaNor = test_retrieval.testbetanormalizednot(
                    opt, trig, dataset, BetaNormalize)
                print(name, ' BetaNormalized: ', betaNor)
                Results.append('No.' + str(i) + ' DataSet=' + name +
                               ' Type= BetaNormalized ' + ' Result=' + str(betaNor))
            except Exception as e:
                print('ERROR', e)

            try:
                asbook = test_retrieval.test(opt, trig, dataset)
                print(name, ' As Paper: ', asbook)
                # bug fix: record asbook here, not betaNor from the block above
                Results.append('No.' + str(i) + ' DataSet=' + name +
                               ' Type= As Paper ' + ' Result=' + str(asbook))
            except Exception as e:
                print('ERROR', e)

    with open(Path1 + r"/" + 'Results15time.txt', 'wb') as fp:
        pickle.dump(Results, fp)
Example #7
def GetValuestrain():

    with open(Path1 + "/trainBetaNormalized.txt", 'rb') as fp:
        BetaNormalize = pickle.load(fp)

    trainset = TestFashion200k(
        path=Path1,
        split='train',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize(224),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                             [0.229, 0.224, 0.225])
        ]))

    trainloader = trainset.get_loader(batch_size=2,
                                      shuffle=True,
                                      drop_last=True,
                                      num_workers=0)

    testset = TestFashion200k(
        path=Path1,
        split='test',
        transform=torchvision.transforms.Compose([
            torchvision.transforms.Resize(224),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                             [0.229, 0.224, 0.225])
        ]))

    trig = TIRG([t.encode().decode('utf-8') for t in trainset.get_all_texts()],
                512)
    #trig.load_state_dict(torch.load(Path1+r'\checkpoint_fashion200k.pth' , map_location=torch.device('cpu') )['model_state_dict'])
    trig.load_state_dict(
        torch.load(Path1 + r'\fashion200k.tirg.iter160k.pth',
                   map_location=torch.device('cpu'))['model_state_dict'])

    # an ArgumentParser instance is used here as a plain options object;
    # only the attributes set below are read by test_retrieval
    opt = argparse.ArgumentParser()
    opt.add_argument('--batch_size', type=int, default=2)
    opt.add_argument('--dataset', type=str, default='fashion200k')
    opt.batch_size = 1
    opt.dataset = 'fashion200k'

    for name, dataset in [('train', trainset),
                          ('test', testset)]:  #('train', trainset),

        # betaN = test_retrieval.testbetaNot(opt, trig, dataset,Beta)
        # print('BetaNotNormalized: ',betaN)

        #try:
        betaNor = test_retrieval.testbetanormalizednot(opt, trig, dataset,
                                                       BetaNormalize)
        print(name, ' BetaNormalized: ', betaNor)
        #except:
        #  print('ERROR')

        #try:
        asbook = test_retrieval.test(opt, trig, dataset)
        print(name, ' As Paper: ', asbook)
Example #8
File: main.py  Project: bbttboy/MyCAE
def train_loop(opt, loss_weights, logger, trainset, testset, model, optimizer):
    """Function for train loop"""
    print('Begin training.')
    print(len(trainset.test_queries), len(testset.test_queries))
    # cudnn.benchmark spends a little extra time at startup searching for the
    # fastest convolution algorithm for each conv layer, speeding up the net.
    # Suitable when the architecture is fixed (not dynamic) and the input
    # shapes (batch size, image size, channels) do not change.
    torch.backends.cudnn.benchmark = True
    losses_tracking = {}
    it = 0  # iteration count
    epoch = -1
    tic = time.time()
    l2_loss = torch.nn.MSELoss().cuda()

    while it < opt.num_iters:
        epoch += 1

        # show/log stats
        # round(x, n) keeps n decimal places; print(x, y, z) accepts non-str
        # arguments and separates them with spaces automatically
        print('It', it, 'epoch', epoch, 'Elapsed time',
              round(time.time() - tic, 4), opt.comment)

        tic = time.time()
        for loss_name in losses_tracking:
            avg_loss = np.mean(losses_tracking[loss_name][-len(trainloader):])
            print('    Loss', loss_name, round(avg_loss, 4))
            logger.add_scalar(loss_name, avg_loss, it)
        # generic TensorBoard API: add_something(tag_name, object, iteration)
        logger.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], it)

        # reclaim cyclically-referenced objects that were destroyed but not
        # freed; `epoch % 1 == 0` holds on every epoch
        if epoch % 1 == 0:
            gc.collect()

        # test
        if epoch % 3 == 1:
            tests = []
            print('Begin testing.')
            for name, dataset in [('train', trainset), ('test', testset)]:
                t = test_retrieval.test(opt, model, dataset)
                tests += [(name + ' ' + metric_name, metric_value)
                          for metric_name, metric_value in t]
            for metric_name, metric_value in tests:
                logger.add_scalar(metric_name, metric_value, it)
                print('    ', metric_name, round(metric_value, 4))

        # save checkpoint
        torch.save(
            {
                'it': it,
                'opt': opt,
                'model_state_dict': model.state_dict(),
            },
            logger.file_writer.get_logdir() + '/latest_checkpoint.pth')

        # run training for 1 epoch
        model.train()
        trainloader = trainset.get_loader(batch_size=opt.batch_size,
                                          shuffle=True,
                                          drop_last=True,
                                          num_workers=opt.loader_num_workers)

        def training_1_iter(data):
            assert type(data) is list
            img1 = np.stack([d['source_img_data'] for d in data])
            img1 = torch.from_numpy(img1).float()
            img1 = torch.autograd.Variable(img1).cuda()

            img2 = np.stack([d['target_img_data'] for d in data])
            img2 = torch.from_numpy(img2).float()
            img2 = torch.autograd.Variable(img2).cuda()

            if opt.use_complete_text_query:
                if opt.dataset == 'mitstates':
                    supp_text = [str(d['noun']) for d in data]
                    mods = [str(d['mod']['str']) for d in data]
                    # text_query here means complete_text_query
                    text_query = [
                        adj + " " + noun for adj, noun in zip(mods, supp_text)
                    ]
                else:
                    text_query = [str(d['target_caption']) for d in data]
            else:
                text_query = [str(d['mod']['str']) for d in data]
            # compute loss
            if opt.loss not in ['soft_triplet', 'batch_based_classification']:
                print('Invalid loss function', opt.loss)
                sys.exit()

            losses = []
            if_soft_triplet = opt.loss == 'soft_triplet'
            loss_value, dct_with_representations = model.compute_loss(
                img1,
                text_query,
                img2,
                soft_triplet_loss=if_soft_triplet,
            )

            loss_name = opt.loss
            losses += [(loss_name, loss_weights[0], loss_value.cuda())]

            if opt.model == 'composeAE':
                dec_img_loss = l2_loss(
                    dct_with_representations["repr_to_compare_with_source"],
                    dct_with_representations["img_features"])
                dec_text_loss = l2_loss(
                    dct_with_representations["repr_to_compare_with_mods"],
                    dct_with_representations["text_features"])

                losses += [("L2_loss", loss_weights[1], dec_img_loss.cuda())
                           ]  # loss_weight=0.1
                losses += [("L2_loss_text", loss_weights[2],
                            dec_text_loss.cuda())]  # loss_weight=0.1
                # loss_weight=0.01
                losses += [("rot_sym_loss", loss_weights[3],
                            dct_with_representations["rot_sym_loss"].cuda())]
            else:
                print("Invalid model.", opt.model)
                sys.exit()

            total_loss = sum([
                loss_weight * loss_value
                for loss_name, loss_weight, loss_value in losses
            ])
            assert not torch.isnan(total_loss)
            losses += [('total training loss', None, total_loss.item())]

            # track losses
            for loss_name, loss_weight, loss_value in losses:
                if loss_name not in losses_tracking:
                    losses_tracking[loss_name] = []
                losses_tracking[loss_name].append(float(loss_value))

            # anomaly detection is costly; typically enabled once, for debugging
            torch.autograd.set_detect_anomaly(True)

            # gradient descent
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        # tqdm wraps any iterable (list, tuple, dict, DataLoader, ...) and
        # displays a progress bar; tqdm is called directly (not tqdm.tqdm)
        # because of `from tqdm import tqdm`
        for data in tqdm(trainloader, desc='Training for epoch ' + str(epoch)):
            it += 1
            training_1_iter(data)

            # decay learning rate
            if it >= opt.learning_rate_decay_frequency and it % opt.learning_rate_decay_frequency == 0:
                for g in optimizer.param_groups:
                    g['lr'] *= 0.1

    print('Finished training.')
Example #9
def train_loop(opt, logger, trainset, testset, model, optimizer):
    """Function for train loop"""
    print('Begin training')
    losses_tracking = {}
    it = 0
    epoch = -1
    tic = time.time()
    best_metric = 0
    while epoch < opt.num_epochs:
        epoch += 1

        # show/log stats
        print('It', it, 'epoch', epoch, 'Elapsed time',
              round(time.time() - tic, 4), opt.comment)
        tic = time.time()
        for loss_name in losses_tracking:
            avg_loss = np.mean(losses_tracking[loss_name][-len(trainloader):])
            print('    Loss', loss_name, round(avg_loss, 4))
            logger.add_scalar(loss_name, avg_loss, it)
        logger.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], it)
        is_improved = False
        # test for first and for every n_epochs
        if (epoch + 1) % opt.n_epochs_validation == 0 or epoch == 0:
            tests = []
            for name, dataset in [('train', trainset), ('test', testset)]:
                if opt.skip_eval_trainset and name == "train":
                    continue
                t, _ = test_retrieval.test(opt, model, dataset)
                tests += [(name + ' ' + metric_name, metric_value)
                          for metric_name, metric_value in t]
                recall_top_1 = t[0][1]
                if name == "test" and recall_top_1 > best_metric:
                    best_metric = recall_top_1
                    is_improved = True
            for metric_name, metric_value in tests:
                logger.add_scalar(metric_name, metric_value, it)
                print('    ', metric_name, round(metric_value, 4))

        if is_improved:
            # save checkpoint if improved from last checkpoint
            print(f"Is Improved {best_metric}")
            torch.save(
                {
                    'it': it,
                    'opt': opt,
                    'model_state_dict': model.state_dict(),
                },
                logger.file_writer.get_logdir() + '/best_checkpoint.pth')

        # run training for 1 epoch
        model.train()
        trainloader = trainset.get_loader(batch_size=opt.batch_size,
                                          shuffle=True,
                                          drop_last=True,
                                          num_workers=opt.loader_num_workers)

        def training_1_iter(data):
            assert type(data) is list
            img1 = np.stack([d['source_img_data'] for d in data])
            img1 = torch.from_numpy(img1).float()
            img1 = torch.autograd.Variable(img1).cuda()
            img2 = np.stack([d['target_img_data'] for d in data])
            img2 = torch.from_numpy(img2).float()
            img2 = torch.autograd.Variable(img2).cuda()
            mods = [str(d['mod']['str']) for d in data]

            # compute loss
            losses = []
            if opt.loss == 'soft_triplet':
                loss_value = model.compute_loss(img1,
                                                mods,
                                                img2,
                                                soft_triplet_loss=True)
            elif opt.loss == 'batch_based_classification':
                loss_value = model.compute_loss(img1,
                                                mods,
                                                img2,
                                                soft_triplet_loss=False)
            else:
                print('Invalid loss function', opt.loss)
                sys.exit()
            loss_name = opt.loss
            loss_weight = 1.0
            losses += [(loss_name, loss_weight, loss_value)]
            total_loss = sum([
                loss_weight * loss_value
                for loss_name, loss_weight, loss_value in losses
            ])
            assert not torch.isnan(total_loss)
            losses += [('total training loss', None, total_loss)]

            # track losses
            for loss_name, loss_weight, loss_value in losses:
                if loss_name not in losses_tracking:
                    losses_tracking[loss_name] = []
                losses_tracking[loss_name].append(float(loss_value))

            # gradient descent
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        for data in tqdm(trainloader, desc='Training for epoch ' + str(epoch)):
            it += 1
            training_1_iter(data)

            # decay learning rate
            if it >= opt.learning_rate_decay_frequency and it % opt.learning_rate_decay_frequency == 0:
                for g in optimizer.param_groups:
                    g['lr'] *= 0.1

    print('Finished training')
Example #10
def train_loop(opt, logger, trainset, testset, model, optimizer, DP_data=None):
  """Function for train loop"""
  print('Begin training')
  losses_tracking = {}
  best_eval = 0
  it = 0
  tic = time.time()
  
  for epoch in range(opt.num_epoch):
    # decay learning rate every `learning_rate_decay_patient` epochs
    if epoch != 0 and epoch % opt.learning_rate_decay_patient == 0:
      for g in optimizer.param_groups:
        g['lr'] *= opt.lr_div

    # run training for 1 epoch
    model.train()
    trainloader = dataloader.DataLoader(trainset,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=opt.loader_num_workers)

    def training_1_iter(data, data_dp=None):
      img1 = data['source_img_data'].cuda()
      img2 = data['target_img_data'].cuda()
      mods = data['mod']['str']

      # compute loss
      losses = []
      if opt.loss == 'batch_based_classification':
        loss_value = model.compute_loss(
            img1, mods, img2, soft_triplet_loss=False)
      else:
        print('Invalid loss function', opt.loss)
        sys.exit()
      loss_name = opt.loss
      loss_weight = 1.0
      losses += [(loss_name, loss_weight, loss_value)]

      total_loss = sum([
          loss_weight * loss_value
          for loss_name, loss_weight, loss_value in losses
      ])
      assert not torch.isnan(total_loss)
      losses += [('total training loss', None, total_loss)]

      # track losses
      for loss_name, loss_weight, loss_value in losses:
        if loss_name not in losses_tracking:
          losses_tracking[loss_name] = []
        losses_tracking[loss_name].append(float(loss_value))

      # gradient descent
      optimizer.zero_grad()
      total_loss.backward()
      optimizer.step()
    
    count_dp_idx = 0
    for data in tqdm(trainloader, desc='Training for epoch ' + str(epoch)):
      it += 1
      training_1_iter(data)
    
    # show/log stats
    print('It', it, 'epoch', epoch, 'Elapsed time', round(time.time() - tic, 4))
    tic = time.time()
    for loss_name in losses_tracking:
      avg_loss = np.mean(losses_tracking[loss_name][-len(trainloader):])
      print('    Loss', loss_name, round(avg_loss, 4))
      logger.add_scalar(loss_name, avg_loss, it)
    logger.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], it)

    # test
    if epoch % opt.eval_frequency == 0:
      tests = []
      all_sims = {}
      rsum = 0
      for dataname in testset.data_name:
        t, sims = test_retrieval.test(opt, model, testset, dataname)
        all_sims[dataname] = sims
        for metric_name, metric_value in t:
          tests += [(metric_name, metric_value)]
          rsum += metric_value
      tests += [('rmean', rsum / 6)]
      for metric_name, metric_value in tests:
        logger.add_scalar(metric_name, metric_value, it)
        print('    ', metric_name, round(metric_value, 2))

      if rsum > best_eval:
        best_eval = rsum
        # save checkpoint
        for dataname in testset.data_name:
          np.save(opt.log_dir + '/val.{}.{}.scores.npy'.format(dataname, opt.model), all_sims[dataname])
        torch.save({
            'it': it,
            'opt': opt,
            'model_state_dict': model.state_dict(),
        },
        logger.file_writer.get_logdir() + '/best_checkpoint.pth')

  print('Finished training')
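
All of the loops above save checkpoints with the same {'it', 'opt', 'model_state_dict'} layout, and Example #2 shows how the test-only path restores one. A minimal resume helper, assuming that layout:

import torch

def load_checkpoint(model, path, map_location='cpu'):
    # assumes the {'it', 'opt', 'model_state_dict'} dict saved by the
    # train loops above; returns the iteration count to resume from
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint['it']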