Example #1
def prepare_training():
    if config.get('resume') is not None:
        sv_file = torch.load(config['resume'])
        model = models.make(sv_file['model'], load_sd=True).cuda()
        optimizer = utils.make_optimizer(model.parameters(),
                                         sv_file['optimizer'],
                                         load_sd=True)
        epoch_start = sv_file['epoch'] + 1
        if config.get('multi_step_lr') is None:
            lr_scheduler = None
        else:
            lr_scheduler = MultiStepLR(optimizer, **config['multi_step_lr'])
        if lr_scheduler is not None:
            for _ in range(epoch_start - 1):
                lr_scheduler.step()
    else:
        model = models.make(config['model']).cuda()
        optimizer = utils.make_optimizer(model.parameters(),
                                         config['optimizer'])
        epoch_start = 1
        if config.get('multi_step_lr') is None:
            lr_scheduler = None
        else:
            lr_scheduler = MultiStepLR(optimizer, **config['multi_step_lr'])

    log('model: #params={}'.format(utils.compute_num_params(model, text=True)))
    return model, optimizer, epoch_start, lr_scheduler
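
Example #1 rebuilds the optimizer from a spec stored in the checkpoint (`sv_file['optimizer']`) and restores its state with `load_sd=True`. A minimal sketch of such a helper, assuming the spec is a dict holding the optimizer name, its constructor arguments, and a saved state dict (the key names below are assumptions, not taken from the snippet):

# minimal sketch, assuming a spec dict like {'name': 'adam', 'args': {'lr': 1e-4}, 'sd': ...}
import torch.optim as optim

def make_optimizer(params, optimizer_spec, load_sd=False):
    optimizer_classes = {'sgd': optim.SGD, 'adam': optim.Adam}
    optimizer = optimizer_classes[optimizer_spec['name']](params, **optimizer_spec['args'])
    if load_sd:
        # restore optimizer state saved alongside the model checkpoint
        optimizer.load_state_dict(optimizer_spec['sd'])
    return optimizer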
Example #2
def main():
    model = get_model()
    device = torch.device('cuda')
    model = model.to(device)
    loader = data.Data(args).train_loader
    rank = torch.Tensor([i for i in range(101)]).cuda()
    for i in range(args.epochs):
        lr = 0.001 if i < 30 else 0.0001
        optimizer = utils.make_optimizer(args, model, lr)
        model.train()
        print('Learning rate:{}'.format(lr))
        start_time = time.time()
        for j, inputs in enumerate(loader):
            img, label, age = inputs
            img = img.to(device)
            label = label.to(device)
            age = age.to(device)
            optimizer.zero_grad()
            outputs = model(img)
            ages = torch.sum(outputs*rank, dim=1)
            loss1 = loss.kl_loss(outputs, label)
            loss2 = loss.L1_loss(ages, age)
            total_loss = loss1 + loss2
            total_loss.backward()
            optimizer.step()
            current_time = time.time()
            print('[Epoch:{}] \t[batch:{}]\t[loss={:.4f}]'.format(i, j, total_loss.item()))
            start_time = time.time()
        torch.save(model, './pretrained/{}.pt'.format(args.model_name))
        torch.save(model.state_dict(), './pretrained/{}_dict.pt'.format(args.model_name))
        if (i + 1) % 2 == 0:
            print('Test: Epoch=[{}]'.format(i))
            test()
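
Examples #2 and #6 call `utils.make_optimizer(args, model, lr)` at the top of every epoch just to switch the learning rate from 0.001 to 0.0001 after epoch 30; rebuilding the optimizer also resets its internal state (momentum or Adam moments). A common alternative, sketched here under the assumption that a single optimizer is created once before the epoch loop, is to update the learning rate in place:

def set_lr(optimizer, lr):
    # overwrite the learning rate of every parameter group of an existing optimizer
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr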
Example #3
def main():
    args = get_args()
    torch.manual_seed(args.seed)

    shape = (224, 224, 3)
    """ define dataloader """
    train_loader, valid_loader, test_loader = make_dataloader(args)
    """ define model architecture """
    model = get_model(args, shape, args.num_classes)

    if torch.cuda.device_count() >= 1:
        print('Model pushed to {} GPU(s), type {}.'.format(
            torch.cuda.device_count(), torch.cuda.get_device_name(0)))
        model = model.cuda()
    else:
        raise ValueError('CPU training is not supported')
    """ define loss criterion """
    criterion = nn.CrossEntropyLoss().cuda()
    """ define optimizer """
    optimizer = make_optimizer(args, model)
    """ define learning rate scheduler """
    scheduler = make_scheduler(args, optimizer)
    """ define trainer, evaluator, result_dictionary """
    result_dict = {
        'args': vars(args),
        'epoch': [],
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'test_acc': []
    }
    trainer = Trainer(model, criterion, optimizer, scheduler)
    evaluator = Evaluator(model, criterion)

    if args.evaluate:
        """ load model checkpoint """
        model.load()
        result_dict = evaluator.test(test_loader, args, result_dict)
    else:
        evaluator.save(result_dict)
        """ define training loop """
        for epoch in range(args.epochs):
            result_dict['epoch'] = epoch
            result_dict = trainer.train(train_loader, epoch, args, result_dict)
            result_dict = evaluator.evaluate(valid_loader, epoch, args,
                                             result_dict)
            evaluator.save(result_dict)
            plot_learning_curves(result_dict, epoch, args)

        result_dict = evaluator.test(test_loader, args, result_dict)
        evaluator.save(result_dict)
        """ save model checkpoint """
        model.save()

    print(result_dict)
    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k
        self.discriminator = discriminator.Discriminator(args, gan_type)
        if gan_type != 'WGAN_GP':
            self.optimizer = utils.make_optimizer(args, self.discriminator)
        else:
            self.optimizer = optim.Adam(self.discriminator.parameters(),
                                        betas=(0, 0.9),
                                        eps=1e-8,
                                        lr=1e-5)
        self.scheduler = utils.make_scheduler(args, self.optimizer)
def runExperiment():
    seed = int(cfg['model_tag'].split('_')[0])
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
    process_dataset(dataset)
    model = eval('models.{}(model_rate=cfg["global_model_rate"]).to(cfg["device"])'.format(cfg['model_name']))
    optimizer = make_optimizer(model, cfg['lr'])
    scheduler = make_scheduler(optimizer)
    if cfg['resume_mode'] == 1:
        last_epoch, data_split, label_split, model, optimizer, scheduler, logger = resume(model, cfg['model_tag'],
                                                                                          optimizer, scheduler)
    elif cfg['resume_mode'] == 2:
        last_epoch = 1
        _, data_split, label_split, model, _, _, _ = resume(model, cfg['model_tag'])
        current_time = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
        logger_path = 'output/runs/{}_{}'.format(cfg['model_tag'], current_time)
        logger = Logger(logger_path)
    else:
        last_epoch = 1
        data_split, label_split = split_dataset(dataset, cfg['num_users'], cfg['data_split_mode'])
        current_time = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
        logger_path = 'output/runs/train_{}_{}'.format(cfg['model_tag'], current_time)
        logger = Logger(logger_path)
    if data_split is None:
        data_split, label_split = split_dataset(dataset, cfg['num_users'], cfg['data_split_mode'])
    global_parameters = model.state_dict()
    federation = Federation(global_parameters, cfg['model_rate'], label_split)
    for epoch in range(last_epoch, cfg['num_epochs']['global'] + 1):
        logger.safe(True)
        train(dataset['train'], data_split['train'], label_split, federation, model, optimizer, logger, epoch)
        test_model = stats(dataset['train'], model)
        test(dataset['test'], data_split['test'], label_split, test_model, logger, epoch)
        if cfg['scheduler_name'] == 'ReduceLROnPlateau':
            scheduler.step(metrics=logger.mean['train/{}'.format(cfg['pivot_metric'])])
        else:
            scheduler.step()
        logger.safe(False)
        model_state_dict = model.state_dict()
        save_result = {
            'cfg': cfg, 'epoch': epoch + 1, 'data_split': data_split, 'label_split': label_split,
            'model_dict': model_state_dict, 'optimizer_dict': optimizer.state_dict(),
            'scheduler_dict': scheduler.state_dict(), 'logger': logger}
        save(save_result, './output/model/{}_checkpoint.pt'.format(cfg['model_tag']))
        if cfg['pivot'] < logger.mean['test/{}'.format(cfg['pivot_metric'])]:
            cfg['pivot'] = logger.mean['test/{}'.format(cfg['pivot_metric'])]
            shutil.copy('./output/model/{}_checkpoint.pt'.format(cfg['model_tag']),
                        './output/model/{}_best.pt'.format(cfg['model_tag']))
        logger.reset()
    logger.safe(False)
    return
Example #6
def main():
    model = get_model()
    device = torch.device('cuda')
    model = model.to(device)
    print(model)
    loader = data.Data(args).train_loader
    rank = torch.Tensor([i for i in range(101)]).cuda()
    best_mae = np.inf
    for i in range(args.epochs):
        lr = 0.001 if i < 30 else 0.0001
        optimizer = utils.make_optimizer(args, model, lr)
        model.train()
        print('Learning rate:{}'.format(lr))
        # start_time = time.time()
        for j, inputs in enumerate(tqdm(loader)):
            img, label, age = inputs['image'], inputs['label'], inputs['age']
            img = img.to(device)
            label = label.to(device)
            age = age.to(device)
            optimizer.zero_grad()
            outputs = model(img)
            ages = torch.sum(outputs * rank, dim=1)
            loss1 = loss.kl_loss(outputs, label)
            loss2 = loss.L1_loss(ages, age)
            total_loss = loss1 + loss2
            total_loss.backward()
            optimizer.step()
            # current_time = time.time()
            if j % 10 == 0:
                tqdm.write('[Epoch:{}] \t[batch:{}]\t[loss={:.4f}]'.format(
                    i, j, total_loss.item()))
            # start_time = time.time()
        torch.save(model, './checkpoint/{}.pt'.format(args.model_name))
        torch.save(model.state_dict(),
                   './checkpoint/{}_dict.pt'.format(args.model_name))
        if (i + 1) % 2 == 0:
            print('Test: Epoch=[{}]'.format(i))
            cur_mae = test(model)
            if cur_mae < best_mae:
                best_mae = cur_mae
                print(f'Saving best model with MAE {cur_mae}... ')
                torch.save(
                    model, './checkpoint/best_{}_MAE={}.pt'.format(
                        args.model_name, cur_mae))
                torch.save(
                    model.state_dict(),
                    './checkpoint/best_{}_dict_MAE={}.pt'.format(
                        args.model_name, cur_mae))
Example #7
    def __init__(self, args, train_loader, val_loader, model, loss):
        self.args = args
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.model = model
        if args.pre_train_model != "...":
            self.model.load_state_dict(torch.load(args.pre_train_model))
        self.loss = loss

        # Actually I am not sure what would happen if I specify the optimizer and scheduler in main.py
        self.optimizer = utils.make_optimizer(self.args, self.model)
        if args.pre_train_optimizer != "...":
            self.optimizer.load_state_dict(torch.load(
                args.pre_train_optimizer))
        self.scheduler = utils.make_scheduler(self.args, self.optimizer)
        self.iter = 0
    def train(self, local_parameters, lr, logger):
        metric = Metric()
        model = eval('models.{}(model_rate=self.model_rate).to(cfg["device"])'.format(cfg['model_name']))
        model.load_state_dict(local_parameters)
        model.train(True)
        optimizer = make_optimizer(model, lr)
        for local_epoch in range(1, cfg['num_epochs']['local'] + 1):
            for i, input in enumerate(self.data_loader):
                input = collate(input)
                input_size = input['img'].size(0)
                input['label_split'] = torch.tensor(self.label_split)
                input = to_device(input, cfg['device'])
                optimizer.zero_grad()
                output = model(input)
                output['loss'].backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                optimizer.step()
                evaluation = metric.evaluate(cfg['metric_name']['train']['Local'], input, output)
                logger.append(evaluation, 'train', n=input_size)
        local_parameters = model.state_dict()
        return local_parameters
def train(config):
    #### set the save and log path ####
    save_path = config['save_path']
    utils.set_save_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(config['save_path'], 'tensorboard'))
    yaml.dump(config, open(os.path.join(config['save_path'], 'classifier_config.yaml'), 'w'))

    device = torch.device('cuda:' + args.gpu)

    #### make datasets ####
    # train
    train_folder = config['dataset_path'] + config['train_dataset_type'] + "/training/frames"
    test_folder = config['dataset_path'] + config['train_dataset_type'] + "/testing/frames"

    # Loading dataset
    train_dataset_args = config['train_dataset_args']
    test_dataset_args = config['test_dataset_args']

    train_dataset = VadDataset(args, video_folder=train_folder, bbox_folder=config['train_bboxes_path'],
                               flow_folder=config['train_flow_path'],
                               transform=transforms.Compose([transforms.ToTensor()]),
                               resize_height=train_dataset_args['h'], resize_width=train_dataset_args['w'],
                               dataset=config['train_dataset_type'], time_step=train_dataset_args['t_length'] - 1,
                               device=device)

    test_dataset = VadDataset(args, video_folder=test_folder, bbox_folder=config['test_bboxes_path'],
                              flow_folder=config['test_flow_path'],
                              transform=transforms.Compose([transforms.ToTensor()]),
                              resize_height=train_dataset_args['h'], resize_width=train_dataset_args['w'],
                              dataset=config['train_dataset_type'], time_step=train_dataset_args['t_length'] - 1,
                              device=device)


    train_dataloader = DataLoader(train_dataset, batch_size=train_dataset_args['batch_size'],
                                  shuffle=True, num_workers=train_dataset_args['num_workers'], drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=test_dataset_args['batch_size'],
                                 shuffle=False, num_workers=test_dataset_args['num_workers'], drop_last=False)

    # for test---- prepare labels
    labels = np.load('./data/frame_labels_' + config['test_dataset_type'] + '.npy')
    if config['test_dataset_type'] == 'shanghai':
        labels = np.expand_dims(labels, 0)
    videos = OrderedDict()
    videos_list = sorted(glob.glob(os.path.join(test_folder, '*')))
    labels_list = []
    label_length = 0
    psnr_list = {}
    for video in sorted(videos_list):
        video_name = video.split('/')[-1]
        videos[video_name] = {}
        videos[video_name]['path'] = video
        videos[video_name]['frame'] = glob.glob(os.path.join(video, '*.jpg'))
        videos[video_name]['frame'].sort()
        videos[video_name]['length'] = len(videos[video_name]['frame'])
        labels_list = np.append(labels_list, labels[0][4 + label_length:videos[video_name]['length'] + label_length])
        label_length += videos[video_name]['length']
        psnr_list[video_name] = []

    # Model setting
    num_unet_layers = 4
    discriminator_num_filters = [128, 256, 512, 512]

    # for gradient loss
    alpha = 1
    # for int loss
    l_num = 2
    pretrain = False

    if config['generator'] == 'cycle_generator_convlstm':
        ngf = 64
        netG = 'resnet_6blocks'
        norm = 'instance'
        no_dropout = False
        init_type = 'normal'
        init_gain = 0.02
        gpu_ids = []
        model = define_G(train_dataset_args['c'], train_dataset_args['c'],
                             ngf, netG, norm, not no_dropout, init_type, init_gain, gpu_ids)
    elif config['generator'] == 'unet':
        # generator = UNet(n_channels=train_dataset_args['c']*(train_dataset_args['t_length']-1),
        #                  layer_nums=num_unet_layers, output_channel=train_dataset_args['c'])
        model = PreAE(train_dataset_args['c'], train_dataset_args['t_length'], **config['model_args'])
    else:
        raise Exception('The generator is not implemented')

    # generator = torch.load('save/avenue_cycle_generator_convlstm_flownet2_0103/generator-epoch-199.pth')
    if config['use_D']:
        discriminator = PixelDiscriminator(train_dataset_args['c'], discriminator_num_filters, use_norm=False)
        optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.00002)

    # optimizer setting
    params_encoder = list(model.parameters())
    params_decoder = list(model.parameters())
    params = params_encoder + params_decoder
    optimizer_G, lr_scheduler = utils.make_optimizer(
        params, config['optimizer'], config['optimizer_args'])    


    # set loss, different range with the source version, should change
    lam_int = 1.0 * 2
    lam_gd = 1.0 * 2
    # TODO here we use no flow loss
    # lam_op = 0  # 2.0
    # op_loss = Flow_Loss()
    
    adversarial_loss = Adversarial_Loss()
    # TODO if use adv
    lam_adv = 0.05
    discriminate_loss = Discriminate_Loss()
    alpha = 1
    l_num = 2
    gd_loss = Gradient_Loss(alpha, train_dataset_args['c'])    
    int_loss = Intensity_Loss(l_num)
    object_loss = ObjectLoss(device, l_num)

    # parallel if muti-gpus
    if torch.cuda.is_available():
        model.cuda()
        if config['use_D']:
            discriminator.cuda()
    if config.get('_parallel'):
        model = nn.DataParallel(model)
        if config['use_D']:
            discriminator = nn.DataParallel(discriminator)
    # Training
    utils.log('Start train')
    max_frame_AUC, max_roi_AUC = 0, 0
    base_channel_num = train_dataset_args['c'] * (train_dataset_args['t_length'] - 1)
    save_epoch = 5 if config['save_epoch'] is None else config['save_epoch']
    for epoch in range(config['epochs']):

        model.train()
        for j, (imgs, bbox, flow) in enumerate(tqdm(train_dataloader, desc='train', leave=False)):
            imgs = imgs.cuda()
            flow = flow.cuda()
            # input = imgs[:, :-1, ].view(imgs.shape[0], -1, imgs.shape[-2], imgs.shape[-1])
            input = imgs[:, :-1, ]
            target = imgs[:, -1, ]
            outputs = model(input)

            if config['use_D']:
                g_adv_loss = adversarial_loss(discriminator(outputs))
            else:
                g_adv_loss = 0 

            g_object_loss = object_loss(outputs, target, flow, bbox)
            # g_int_loss = int_loss(outputs, target)
            g_gd_loss = gd_loss(outputs, target)
            g_loss = lam_adv * g_adv_loss + lam_gd * g_gd_loss + lam_int * g_object_loss

            optimizer_G.zero_grad()
            g_loss.backward()

            optimizer_G.step()

            train_psnr = utils.psnr_error(outputs,target)

            # ----------- update optim_D -------
            if config['use_D']:
                optimizer_D.zero_grad()
                d_loss = discriminate_loss(discriminator(target), discriminator(outputs.detach()))
                d_loss.backward()
                optimizer_D.step()
        lr_scheduler.step()

        utils.log('----------------------------------------')
        utils.log('Epoch:' + str(epoch + 1))
        utils.log('----------------------------------------')
        utils.log('Loss: Reconstruction {:.6f}'.format(g_loss.item()))

        # Testing
        utils.log('Evaluation of ' + config['test_dataset_type'])   


        # Save the model
        if epoch % save_epoch == 0 or epoch == config['epochs'] - 1:
            if not os.path.exists(save_path):
                os.makedirs(save_path) 
            if not os.path.exists(os.path.join(save_path, "models")):
                os.makedirs(os.path.join(save_path, "models")) 
            # TODO 
            frame_AUC = ObjectLoss_evaluate(test_dataloader, model, labels_list, videos, dataset=config['test_dataset_type'],device = device,
                frame_height = train_dataset_args['h'], frame_width=train_dataset_args['w'],
                is_visual=False, mask_labels_path = config['mask_labels_path'], save_path = os.path.join(save_path, "./final"), labels_dict=labels) 
            
            torch.save(model.state_dict(), os.path.join(save_path, 'models/model-epoch-{}.pth'.format(epoch)))
            if config['use_D']:
                torch.save(discriminator.state_dict(), os.path.join(save_path, 'models/discriminator-epoch-{}.pth'.format(epoch)))
        else:
            frame_AUC = ObjectLoss_evaluate(test_dataloader, model, labels_list, videos, dataset=config['test_dataset_type'],device=device,
                frame_height = train_dataset_args['h'], frame_width=train_dataset_args['w']) 

        utils.log('The result of ' + config['test_dataset_type'])
        utils.log("AUC: {}%".format(frame_AUC*100))

        if frame_AUC > max_frame_AUC:
            max_frame_AUC = frame_AUC
            # TODO
            torch.save(model.state_dict(), os.path.join(save_path, 'models/max-frame_auc-model.pth'))
            if config['use_D']:
                torch.save(discriminator.state_dict(), os.path.join(save_path, 'models/discriminator-epoch-{}.pth'.format(epoch)))
            # evaluate(test_dataloader, model, labels_list, videos, int_loss, config['test_dataset_type'], test_bboxes=config['test_bboxes'],
            #     frame_height = train_dataset_args['h'], frame_width=train_dataset_args['w'], 
            #     is_visual=True, mask_labels_path = config['mask_labels_path'], save_path = os.path.join(save_path, "./frame_best"), labels_dict=labels) 
        
        utils.log('----------------------------------------')

    utils.log('Training is finished')
    utils.log('max_frame_AUC: {}'.format(max_frame_AUC))
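
This example expects `utils.make_optimizer(params, config['optimizer'], config['optimizer_args'])` to return both the optimizer and a per-epoch learning-rate scheduler. A minimal sketch of a helper with that signature (the scheduler type and its step settings are placeholders, not values from the snippet):

# minimal sketch, assuming the optimizer name selects a torch.optim class and the args
# dict holds its constructor keyword arguments
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

def make_optimizer(params, optimizer_name, optimizer_args):
    optimizer_classes = {'sgd': optim.SGD, 'adam': optim.Adam}
    optimizer = optimizer_classes[optimizer_name](params, **optimizer_args)
    lr_scheduler = StepLR(optimizer, step_size=30, gamma=0.1)  # placeholder schedule
    return optimizer, lr_scheduler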
Example #10
    def train_model():
        # build model
        model = build_model(hparams, **dataset.preprocessing.kwargs)

        # compile model
        model.compile(optimizer=utils.make_optimizer(hparams.optimizer,
                                                     hparams.opt_param),
                      loss="categorical_crossentropy",
                      metrics=["accuracy"])

        # print summary of created model
        lq.models.summary(model)

        # if model already exists, load it and continue training
        initial_epoch = 0
        if os.path.exists(os.path.join(model_dir, "stats.json")):
            with open(os.path.join(model_dir, "stats.json"),
                      "r") as stats_file:
                model_path = os.path.join(model_dir, "weights.h5")
                initial_epoch = json.load(stats_file)["epoch"]
                click.echo(
                    f"Restoring model from {model_path} at epoch = {initial_epoch}"
                )
                model.load_weights(model_path)

        # attach callbacks
        # save model at the end of each epoch
        training_callbacks = [callbacks.SaveStats(model_dir=model_dir)]
        # compute MI
        mi_estimator = callbacks.EstimateMI(dataset,
                                            hparams.mi_layer_types,
                                            log_file=os.path.join(
                                                model_dir, "mi_data.json"))
        training_callbacks.extend([mi_estimator])
        # custom progress bar
        training_callbacks.extend(
            [callbacks.ProgressBar(initial_epoch, ["accuracy"])])

        # train the model
        train_log = model.fit(
            dataset.train_data(hparams.batch_size),
            epochs=hparams.epochs,
            steps_per_epoch=dataset.train_examples // hparams.batch_size,
            validation_data=dataset.validation_data(hparams.batch_size),
            validation_steps=dataset.validation_examples // hparams.batch_size,
            initial_epoch=initial_epoch,
            callbacks=training_callbacks,
            verbose=0)

        # # ==================================================================================
        # import numpy as np
        # import matplotlib.pyplot as plt
        # mi_data = mi_estimator.mi_data
        # sm = plt.cm.ScalarMappable(cmap='gnuplot', norm=plt.Normalize(vmin=0, vmax=config['epochs']))
        # sm._A = []

        # # plot infoplane evolution
        # n_layer_types = len(mi_data)
        # fig, ax = plt.subplots(nrows=1, ncols=n_layer_types, figsize=(5*n_layer_types, 5))
        # if n_layer_types == 1:
        #     ax = [ax]
        # ax = dict(zip(mi_data.keys(), ax))
        # for layer_type, layer_data in mi_data.items():
        #     for layer_name, mi_values in layer_data.items():
        #         c = [sm.to_rgba(int(epoch)) for epoch in mi_values.keys()]

        #         mi = np.stack([mi_val for (_, mi_val) in mi_values.items()])
        #         ax[layer_type].scatter(mi[:,0], mi[:,1], c=c)

        #     epochs = list(layer_data[next(iter(layer_data))].keys())
        #     for epoch_idx in epochs:
        #         x_data = []
        #         y_data = []
        #         for layer_name, mi_values in layer_data.items():
        #             x_data.append(mi_values[epoch_idx][0])
        #             y_data.append(mi_values[epoch_idx][1])
        #         ax[layer_type].plot(x_data, y_data, c='k', alpha=0.1)

        #     ax[layer_type].set_title(layer_type)
        #     ax[layer_type].grid()

        # cbaxes = fig.add_axes([1.0, 0.10, 0.05, 0.85])
        # plt.colorbar(sm, label='Epoch', cax=cbaxes)
        # plt.tight_layout()

        # # plot layerwise
        # for layer_type, layer_data in  mi_data.items():
        #     if layer_data:
        #         n_layers = len(layer_data)
        #         fig, ax = plt.subplots(nrows=1, ncols=n_layers, figsize=(3*n_layers, 3))
        #         ax = dict(zip(layer_data.keys(), ax))
        #         for (layer_name, mi_values) in  layer_data.items():
        #             c = [sm.to_rgba(int(epoch)) for epoch in mi_values.keys()]

        #             mi = np.stack([mi_val for (_, mi_val) in mi_values.items()])
        #             ax[layer_name].scatter(mi[:,0], mi[:,1], c=c)
        #             ax[layer_name].set_title(layer_name)
        #             ax[layer_name].set_xlabel("I(T;X)")
        #             ax[layer_name].set_ylabel("I(T;Y)")
        #             ax[layer_name].grid()
        #         cbaxes = fig.add_axes([1.0, 0.1, 0.01, 0.80])
        #         plt.colorbar(sm, label='Epoch', cax=cbaxes)
        #         plt.tight_layout()

        # plt.show()
        # # ==================================================================================

        return train_log
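
Example #10 is the only Keras snippet: `utils.make_optimizer(hparams.optimizer, hparams.opt_param)` must return something `model.compile` accepts as an optimizer. A hedged sketch, assuming the first argument is an optimizer name and the second a dict of constructor parameters:

# minimal sketch of a Keras-side make_optimizer; the name/parameter mapping is an assumption
import tensorflow as tf

def make_optimizer(name, opt_param):
    optimizer_classes = {
        'sgd': tf.keras.optimizers.SGD,
        'adam': tf.keras.optimizers.Adam,
    }
    # e.g. opt_param = {'learning_rate': 1e-3}
    return optimizer_classes[name.lower()](**opt_param)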
Example #11
    print("=======================================================")

    cfg = options.get_arguments()

    EXPERIMENT = f"{cfg.model}_{cfg.experiment}"
    MODEL_PATH = f"models/{EXPERIMENT}"
    LOG_PATH = f"logs/{EXPERIMENT}"

    utils.make_folder(MODEL_PATH)
    utils.make_folder(LOG_PATH)

    criterions = utils.define_losses()
    dataloaders = utils.make_data_novel(cfg)

    model = utils.build_structure_generator(cfg).to(cfg.device)
    optimizer = utils.make_optimizer(cfg, model)
    scheduler = utils.make_lr_scheduler(cfg, optimizer)

    logger = utils.make_logger(LOG_PATH)
    writer = utils.make_summary_writer(EXPERIMENT)

    def on_after_epoch(model, df_hist, images, epoch, saveEpoch):
        utils.save_best_model(MODEL_PATH, model, df_hist)
        utils.checkpoint_model(MODEL_PATH, model, epoch, saveEpoch)
        utils.log_hist(logger, df_hist)
        utils.write_on_board_losses_stg2(writer, df_hist)
        utils.write_on_board_images_stg2(writer, images, epoch)

    if cfg.lrSched is not None:
        def on_after_batch(iteration):
            utils.write_on_board_lr(writer, scheduler.get_lr(), iteration)
Example #12
    log_dir = os.path.join(os.path.expanduser('./log'), time_str)
    if not os.path.isdir(
            log_dir):  # Create the log directory if it doesn't exist
        os.makedirs(log_dir)
    set_logger(logger, log_dir)
    log_file = os.path.join(log_dir, opt.version + '.txt')
    with open(log_file, 'a') as f:
        f.write(str(opt) + '\n')
        f.flush()

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus
    cudnn.benchmark = True

    data = Data()
    model = build_model(opt, data.num_classes)
    optimizer = make_optimizer(opt, model)
    loss = make_loss(opt, data.num_classes)

    # WARMUP_FACTOR: 0.01
    # WARMUP_ITERS: 10
    scheduler = WarmupMultiStepLR(optimizer, opt.steps, 0.1, 0.01, 10,
                                  "linear")
    main = Main(opt, model, data, optimizer, scheduler, loss)

    if opt.mode == 'train':

        # total number of training epochs
        epoch = 200
        start_epoch = 1

        # resume training from a checkpoint
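
Example #12 wraps the optimizer in `WarmupMultiStepLR(optimizer, opt.steps, 0.1, 0.01, 10, "linear")`, i.e. gamma 0.1, warmup factor 0.01, and 10 warmup epochs. The learning-rate rule such schedulers usually implement is sketched below; treat the exact formula as an assumption based on the common reid-baseline/detectron-style scheduler, not a quote from this repository:

# sketch of the per-epoch learning rate produced by a linear-warmup multi-step scheduler
from bisect import bisect_right

def warmup_multistep_lr(base_lr, epoch, steps, gamma=0.1,
                        warmup_factor=0.01, warmup_iters=10, warmup_method="linear"):
    if epoch < warmup_iters:
        if warmup_method == "linear":
            alpha = epoch / warmup_iters
            warmup = warmup_factor * (1 - alpha) + alpha  # ramps 0.01 -> 1.0
        else:  # "constant"
            warmup = warmup_factor
    else:
        warmup = 1.0
    # multi-step decay: multiply by gamma for every milestone already passed
    return base_lr * warmup * gamma ** bisect_right(sorted(steps), epoch)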
Example #13
def train_test(Xtrain, Ytrain, Xtest, Ytest, paras, outpath):
    train_no, x_dim = Xtrain.shape
    try:
        test_no, y_dim = Ytest.shape
    except ValueError:  # Ytest is 1-D, so only one value can be unpacked
        test_no = Ytest.shape[0]
        y_dim = 1

    hypers = {
        "x_dim": x_dim,
        "y_dim": y_dim,
        "hidden_dims": paras["hidden_dims"],
        "nonlinearity": "relu",
        "adapter": {
            'in': paras['in'],
            'out': paras['out']
        },
        "method": "bayes",
        "style": "heteroskedastic",
        "homo_logvar_scale": 2 * np.log(0.2),
        "prior_type": ["empirical", "wider_he", "wider_he"],
        "n_epochs": paras['n_epochs'],
        # "batch_size": 32,
        "batch_size": train_no,
        "learning_rate": paras['learning_rate'],
        "lambda": 1.0,
        "warmup_updates": {
            'lambda': 14000.0
        },
        "anneal_updates": {
            'lambda': 1000.0
        },
        "optimizer": "adam",
        "gradient_clip": 0.1,
        "data_fraction": 1.0,
        "sections_to_run": ["train", 'test']
    }

    data = [[Xtrain, Ytrain.reshape(-1)], [Xtest, Ytest.reshape(-1)]]

    restricted_training_set = u.restrict_dataset_size(data[0],
                                                      hypers['data_fraction'])
    hypers['dataset_size'] = len(restricted_training_set[0])

    device_id = 1
    device_string = u.get_device_string(device_id)
    print(hypers)
    with tf.device(device_string):
        if True:
            model_and_metrics = make_model(hypers)

            train_op = u.make_optimizer(model_and_metrics, hypers)
            sess = u.get_session()
            saver = tf.train.Saver()

            all_summaries = []
            best_valid_accuracy = np.inf

        for epoch in range(1, hypers['n_epochs'] + 1):
            verbose = (epoch % 20 == 0)
            if verbose:
                print("Epoch %i:        " % epoch, end='')

            epoch_summary, accuracies = u.train_valid_test(
                {
                    'train': restricted_training_set,
                    'test': data[1]
                }, sess, model_and_metrics, train_op, hypers, verbose)
            # dump log file
            all_summaries.append(epoch_summary)

            if epoch % 5000 == 0:
                saver.save(sess,
                           os.path.join(outpath, 'model.ckpt'),
                           global_step=epoch)

        with open(os.path.join(outpath, "summaries.json"), 'w') as f:
            json.dump(all_summaries, f, indent=4, cls=u.NumpyEncoder)

    return None
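
Example #13 runs on TensorFlow 1.x; `u.make_optimizer(model_and_metrics, hypers)` is expected to return a training op built from the model's loss. A rough sketch under that assumption, using the 'adam' optimizer and the `gradient_clip` value present in `hypers` (the `loss` argument and the function name here are hypothetical):

# minimal sketch of a TF1 train-op builder with gradient clipping
import tensorflow as tf

def make_train_op(loss, hypers):
    opt = tf.train.AdamOptimizer(learning_rate=hypers['learning_rate'])
    grads_and_vars = opt.compute_gradients(loss)
    clipped = [(tf.clip_by_norm(g, hypers['gradient_clip']), v)
               for g, v in grads_and_vars if g is not None]
    return opt.apply_gradients(clipped)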
Example #14
def main(config):
    svname = args.name
    if svname is None:
        svname = 'meta_{}-{}shot'.format(config['train_dataset'],
                                         config['n_shot'])
        svname += '_' + config['model']
        if config['model_args'].get('encoder'):
            svname += '-' + config['model_args']['encoder']
        if config['model_args'].get('prog_synthesis'):
            svname += '-' + config['model_args']['prog_synthesis']
    svname += '-seed' + str(args.seed)
    if args.tag is not None:
        svname += '_' + args.tag

    save_path = os.path.join(args.save_dir, svname)
    utils.ensure_path(save_path, remove=False)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    logger = utils.Logger(file_name=os.path.join(save_path, "log_sdout.txt"),
                          file_mode="a+",
                          should_flush=True)

    #### Dataset ####

    n_way, n_shot = config['n_way'], config['n_shot']
    n_query = config['n_query']

    if config.get('n_train_way') is not None:
        n_train_way = config['n_train_way']
    else:
        n_train_way = n_way
    if config.get('n_train_shot') is not None:
        n_train_shot = config['n_train_shot']
    else:
        n_train_shot = n_shot
    if config.get('ep_per_batch') is not None:
        ep_per_batch = config['ep_per_batch']
    else:
        ep_per_batch = 1

    random_state = np.random.RandomState(args.seed)
    print('seed:', args.seed)

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    utils.log('train dataset: {} (x{})'.format(train_dataset[0][0].shape,
                                               len(train_dataset)))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)
    train_sampler = BongardSampler(train_dataset.n_tasks,
                                   config['train_batches'], ep_per_batch,
                                   random_state.randint(2**31))
    train_loader = DataLoader(train_dataset,
                              batch_sampler=train_sampler,
                              num_workers=8,
                              pin_memory=True)

    # tvals
    tval_loaders = {}
    tval_name_ntasks_dict = {
        'tval': 2000,
        'tval_ff': 600,
        'tval_bd': 480,
        'tval_hd_comb': 400,
        'tval_hd_novel': 320
    }  # numbers depend on dataset
    for tval_type in tval_name_ntasks_dict.keys():
        if config.get('{}_dataset'.format(tval_type)):
            tval_dataset = datasets.make(
                config['{}_dataset'.format(tval_type)],
                **config['{}_dataset_args'.format(tval_type)])
            utils.log('{} dataset: {} (x{})'.format(tval_type,
                                                    tval_dataset[0][0].shape,
                                                    len(tval_dataset)))
            if config.get('visualize_datasets'):
                utils.visualize_dataset(tval_dataset, 'tval_ff_dataset',
                                        writer)
            tval_sampler = BongardSampler(
                tval_dataset.n_tasks,
                n_batch=tval_name_ntasks_dict[tval_type] // ep_per_batch,
                ep_per_batch=ep_per_batch,
                seed=random_state.randint(2**31))
            tval_loader = DataLoader(tval_dataset,
                                     batch_sampler=tval_sampler,
                                     num_workers=8,
                                     pin_memory=True)
            tval_loaders.update({tval_type: tval_loader})
        else:
            tval_loaders.update({tval_type: None})

    # val
    val_dataset = datasets.make(config['val_dataset'],
                                **config['val_dataset_args'])
    utils.log('val dataset: {} (x{})'.format(val_dataset[0][0].shape,
                                             len(val_dataset)))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    val_sampler = BongardSampler(val_dataset.n_tasks,
                                 n_batch=900 // ep_per_batch,
                                 ep_per_batch=ep_per_batch,
                                 seed=random_state.randint(2**31))
    val_loader = DataLoader(val_dataset,
                            batch_sampler=val_sampler,
                            num_workers=8,
                            pin_memory=True)

    ########

    #### Model and optimizer ####

    if config.get('load'):
        print('loading pretrained model: ', config['load'])
        model = models.load(torch.load(config['load']))
    else:
        model = models.make(config['model'], **config['model_args'])

        if config.get('load_encoder'):
            print('loading pretrained encoder: ', config['load_encoder'])
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())

        if config.get('load_prog_synthesis'):
            print('loading pretrained program synthesis model: ',
                  config['load_prog_synthesis'])
            prog_synthesis = models.load(
                torch.load(config['load_prog_synthesis']))
            model.prog_synthesis.load_state_dict(prog_synthesis.state_dict())

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'vl', 'va']
    tval_tuple_lst = []
    for k, v in tval_loaders.items():
        if v is not None:
            loss_key = 'tvl' + k.split('tval')[-1]
            acc_key = 'tva' + k.split('tval')[-1]
            aves_keys.append(loss_key)
            aves_keys.append(acc_key)
            tval_tuple_lst.append((k, v, loss_key, acc_key))

    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, label in tqdm(train_loader, desc='train', leave=False):

            x_shot, x_query = fs.split_shot_query(data.cuda(),
                                                  n_train_way,
                                                  n_train_shot,
                                                  n_query,
                                                  ep_per_batch=ep_per_batch)
            label_query = fs.make_nk_label(n_train_way,
                                           n_query,
                                           ep_per_batch=ep_per_batch).cuda()

            if config['model'] == 'snail':  # only use one selected label_query
                query_dix = random_state.randint(n_train_way * n_query)
                label_query = label_query.view(ep_per_batch, -1)[:, query_dix]
                x_query = x_query[:, query_dix:query_dix + 1]

            if config['model'] == 'maml':  # need grad in maml
                model.zero_grad()

            logits = model(x_shot, x_query).view(-1, n_train_way)
            loss = F.cross_entropy(logits, label_query)
            acc = utils.compute_acc(logits, label_query)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        model.eval()

        for name, loader, name_l, name_a in [('val', val_loader, 'vl', 'va')
                                             ] + tval_tuple_lst:

            if config.get('{}_dataset'.format(name)) is None:
                continue

            np.random.seed(0)
            for data, _ in tqdm(loader, desc=name, leave=False):
                x_shot, x_query = fs.split_shot_query(
                    data.cuda(),
                    n_way,
                    n_shot,
                    n_query,
                    ep_per_batch=ep_per_batch)
                label_query = fs.make_nk_label(
                    n_way, n_query, ep_per_batch=ep_per_batch).cuda()

                if config[
                        'model'] == 'snail':  # only use one randomly selected label_query
                    query_dix = random_state.randint(n_train_way)
                    label_query = label_query.view(ep_per_batch, -1)[:,
                                                                     query_dix]
                    x_query = x_query[:, query_dix:query_dix + 1]

                if config['model'] == 'maml':  # need grad in maml
                    model.zero_grad()
                    logits = model(x_shot, x_query, eval=True).view(-1, n_way)
                    loss = F.cross_entropy(logits, label_query)
                    acc = utils.compute_acc(logits, label_query)
                else:
                    with torch.no_grad():
                        logits = model(x_shot, x_query,
                                       eval=True).view(-1, n_way)
                        loss = F.cross_entropy(logits, label_query)
                        acc = utils.compute_acc(logits, label_query)

                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        log_str = 'epoch {}, train {:.4f}|{:.4f}, val {:.4f}|{:.4f}'.format(
            epoch, aves['tl'], aves['ta'], aves['vl'], aves['va'])
        for tval_name, _, loss_key, acc_key in tval_tuple_lst:
            log_str += ', {} {:.4f}|{:.4f}'.format(tval_name, aves[loss_key],
                                                   aves[acc_key])
            writer.add_scalars('loss', {tval_name: aves[loss_key]}, epoch)
            writer.add_scalars('acc', {tval_name: aves[acc_key]}, epoch)
        log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        utils.log(log_str)

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                       os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()

    print('finished training!')
    logger.close()
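
Examples #14 and #15 build episode labels with `fs.make_nk_label(n_way, n_query, ep_per_batch=...)` and compare them against logits reshaped to `(-1, n_way)`. A plausible sketch of what such a helper returns, assuming query labels are laid out class-by-class within each episode (this layout is an assumption, not confirmed by the snippets):

# sketch: class ids 0..n_way-1, each repeated n_query times, tiled over the episodes in a batch
import torch

def make_nk_label(n_way, n_query, ep_per_batch=1):
    label = torch.arange(n_way).unsqueeze(1).expand(n_way, n_query).reshape(-1)
    return label.repeat(ep_per_batch)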
Example #15
def main(config):
    svname = args.name
    if svname is None:
        svname = 'meta'
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    if args.dataset == 'all':
        train_lst = ['ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd',
                'quickdraw', 'fungi', 'vgg_flower']
        eval_lst = ['ilsvrc_2012']
    else:
        train_lst = [args.dataset]
        eval_lst = [args.dataset]

    if config.get('no_train') == True:
        train_iter = None
    else:
        trainset = make_md(train_lst, 'episodic', split='train', image_size=126)
        train_iter = trainset.make_one_shot_iterator().get_next()

    if config.get('no_val') == True:
        val_iter = None
    else:
        valset = make_md(eval_lst, 'episodic', split='val', image_size=126)
        val_iter = valset.make_one_shot_iterator().get_next()

    testset = make_md(eval_lst, 'episodic', split='test', image_size=126)
    test_iter = testset.make_one_shot_iterator().get_next()

    sess = tf.Session()

    ########

    #### Model and optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

        if config.get('load_encoder'):
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(
            model.parameters(),
            config['optimizer'], **config['optimizer_args'])

    ########
    
    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'tvl', 'tva', 'vl', 'va']
    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    def process_data(e):
        e = list(e[0])
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(146),
            transforms.CenterCrop(128),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
        ])
        for ii in [0, 3]:
            e[ii] = ((e[ii] + 1.0) * 0.5 * 255).astype('uint8')
            tmp = torch.zeros(len(e[ii]), 3, 128, 128).float()
            for i in range(len(e[ii])):
                tmp[i] = transform(e[ii][i])
            e[ii] = tmp.cuda()

        e[1] = torch.from_numpy(e[1]).long().cuda()
        e[4] = torch.from_numpy(e[4]).long().cuda()

        return e

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model) 
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        if config.get('no_train') == True:
            pass
        else:
            for i_ep in tqdm(range(config['n_train'])):

                e = process_data(sess.run(train_iter))
                loss, acc = model(e[0], e[1], e[3], e[4])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                aves['tl'].add(loss.item())
                aves['ta'].add(acc)

                loss = None 

        # eval
        model.eval()

        for name, ds_iter, name_l, name_a in [
                ('tval', val_iter, 'tvl', 'tva'),
                ('val', test_iter, 'vl', 'va')]:
            if config.get('no_val') == True and name == 'tval':
                continue

            for i_ep in tqdm(range(config['n_eval'])):

                e = process_data(sess.run(ds_iter))

                with torch.no_grad():
                    loss, acc = model(e[0], e[1], e[3], e[4])
                
                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        _sig = 0

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        utils.log('epoch {}, train {:.4f}|{:.4f}, tval {:.4f}|{:.4f}, '
                'val {:.4f}|{:.4f}, {} {}/{} (@{})'.format(
                epoch, aves['tl'], aves['ta'], aves['tvl'], aves['tva'],
                aves['vl'], aves['va'], t_epoch, t_used, t_estimate, _sig))

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'tval': aves['tvl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'tval': aves['tva'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,

            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),

            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()
Example #16
    def __init__(self, cfg, model, train_dl, val_dl, loss_func, num_gpus,
                 device):

        self.cfg = cfg
        self.model = model
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.loss_func = loss_func

        self.loss_avg = AvgerageMeter()
        self.acc_avg = AvgerageMeter()
        self.f1_avg = AvgerageMeter()

        self.val_loss_avg = AvgerageMeter()
        self.val_acc_avg = AvgerageMeter()
        self.device = device

        self.train_epoch = 1

        if cfg.SOLVER.USE_WARMUP:
            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR * 0.1,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
        else:
            self.optim = make_optimizer(
                self.model,
                opt=self.cfg.SOLVER.OPTIMIZER_NAME,
                lr=cfg.SOLVER.BASE_LR,
                weight_decay=self.cfg.SOLVER.WEIGHT_DECAY,
                momentum=0.9)
        if cfg.SOLVER.RESUME:
            print("Resume from checkpoint...")
            checkpoint = torch.load(cfg.SOLVER.RESUME_CHECKPOINT)
            param_dict = checkpoint['model_state_dict']
            self.optim.load_state_dict(checkpoint['optimizer_state_dict'])
            for state in self.optim.state.values():
                for k, v in state.items():
                    print(type(v))
                    if torch.is_tensor(v):
                        state[k] = v.to(self.device)
            self.train_epoch = checkpoint['epoch'] + 1
            for i in param_dict:
                if i.startswith("module"):
                    new_i = i[7:]
                else:
                    new_i = i
                if 'classifier' in i or 'fc' in i:
                    continue
                self.model.state_dict()[new_i].copy_(param_dict[i])

        self.batch_cnt = 0

        self.logger = logging.getLogger('baseline.train')
        self.log_period = cfg.SOLVER.LOG_PERIOD
        self.checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
        self.eval_period = cfg.SOLVER.EVAL_PERIOD
        self.output_dir = cfg.OUTPUT_DIR

        self.epochs = cfg.SOLVER.MAX_EPOCHS

        if cfg.SOLVER.TENSORBOARD.USE:
            summary_dir = os.path.join(cfg.OUTPUT_DIR, 'summaries/')
            os.makedirs(summary_dir, exist_ok=True)
            self.summary_writer = SummaryWriter(log_dir=summary_dir)
        self.current_iteration = 0

        self.logger.info(self.model)

        if self.cfg.SOLVER.USE_WARMUP:

            scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optim, self.epochs, eta_min=cfg.SOLVER.MIN_LR)
            self.scheduler = GradualWarmupScheduler(
                self.optim,
                multiplier=10,
                total_epoch=cfg.SOLVER.WARMUP_EPOCH,
                after_scheduler=scheduler_cosine)
            # self.scheduler = WarmupMultiStepLR(self.optim, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
            #                                cfg.SOLVER.WARMUP_EPOCH, cfg.SOLVER.WARMUP_METHOD)
        else:
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optim, self.epochs, eta_min=cfg.SOLVER.MIN_LR)

        if num_gpus > 1:

            self.logger.info(self.optim)
            self.model = nn.DataParallel(self.model)
            if cfg.SOLVER.SYNCBN:
                self.model = convert_model(self.model)
                self.model = self.model.to(device)
                self.logger.info(
                    'More than one gpu used, convert model to use SyncBN.')
                self.logger.info('Using pytorch SyncBN implementation')
                self.logger.info(self.model)

            self.logger.info('Trainer Built')

            return

        else:
            self.model = self.model.to(device)
            self.logger.info('Cpu used.')
            self.logger.info(self.model)
            self.logger.info('Trainer Built')

            return
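
In Example #16, when `cfg.SOLVER.USE_WARMUP` is set the optimizer is built with `BASE_LR * 0.1`, and `GradualWarmupScheduler(multiplier=10, ...)` then scales that value back up to `BASE_LR` over `WARMUP_EPOCH` epochs before handing control to the cosine scheduler. A small sketch of that ramp, assuming the usual linear interpolation used by GradualWarmupScheduler:

def warmup_lr(base_lr, epoch, warmup_epochs, multiplier=10):
    # base_lr is the already-scaled value (cfg.SOLVER.BASE_LR * 0.1); at the end of warmup
    # the rate reaches multiplier * base_lr, i.e. the original cfg.SOLVER.BASE_LR
    if epoch >= warmup_epochs:
        return base_lr * multiplier
    return base_lr * (1.0 + (multiplier - 1.0) * epoch / warmup_epochs)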
Example #17
def main():
    runId = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    cfg.OUTPUT_DIR = os.path.join(cfg.OUTPUT_DIR, runId)
    if not os.path.exists(cfg.OUTPUT_DIR):
        os.mkdir(cfg.OUTPUT_DIR)
    print(cfg.OUTPUT_DIR)
    torch.manual_seed(cfg.RANDOM_SEED)
    random.seed(cfg.RANDOM_SEED)
    np.random.seed(cfg.RANDOM_SEED)
    os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID

    use_gpu = torch.cuda.is_available() and cfg.MODEL.DEVICE == "cuda"
    if not cfg.EVALUATE_ONLY:
        sys.stdout = Logger(osp.join(cfg.OUTPUT_DIR, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(cfg.OUTPUT_DIR, 'log_test.txt'))

    print("==========\nConfigs:{}\n==========".format(cfg))

    if use_gpu:
        print("Currently using GPU {}".format(cfg.MODEL.DEVICE_ID))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(cfg.RANDOM_SEED)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(cfg.DATASETS.NAME))

    dataset = data_manager.init_dataset(root=cfg.DATASETS.ROOT_DIR,
                                        name=cfg.DATASETS.NAME)
    print("Initializing model: {}".format(cfg.MODEL.NAME))

    if cfg.MODEL.ARCH == 'video_baseline':
        torch.backends.cudnn.benchmark = False
        model = models.init_model(name=cfg.MODEL.ARCH,
                                  num_classes=625,
                                  pretrain_choice=cfg.MODEL.PRETRAIN_CHOICE,
                                  last_stride=cfg.MODEL.LAST_STRIDE,
                                  neck=cfg.MODEL.NECK,
                                  model_name=cfg.MODEL.NAME,
                                  neck_feat=cfg.TEST.NECK_FEAT,
                                  model_path=cfg.MODEL.PRETRAIN_PATH)

    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    transform_train = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TRAIN),
        T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
        T.Pad(cfg.INPUT.PADDING),
        T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        T.RandomErasing(probability=cfg.INPUT.RE_PROB,
                        mean=cfg.INPUT.PIXEL_MEAN)
    ])
    transform_test = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TEST),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    pin_memory = True if use_gpu else False

    cfg.DATALOADER.NUM_WORKERS = 0

    trainloader = DataLoader(VideoDataset(
        dataset.train,
        seq_len=cfg.DATASETS.SEQ_LEN,
        sample=cfg.DATASETS.TRAIN_SAMPLE_METHOD,
        transform=transform_train,
        dataset_name=cfg.DATASETS.NAME),
                             sampler=RandomIdentitySampler(
                                 dataset.train,
                                 num_instances=cfg.DATALOADER.NUM_INSTANCE),
                             batch_size=cfg.SOLVER.SEQS_PER_BATCH,
                             num_workers=cfg.DATALOADER.NUM_WORKERS,
                             pin_memory=pin_memory,
                             drop_last=True)

    queryloader = DataLoader(VideoDataset(
        dataset.query,
        seq_len=cfg.DATASETS.SEQ_LEN,
        sample=cfg.DATASETS.TEST_SAMPLE_METHOD,
        transform=transform_test,
        max_seq_len=cfg.DATASETS.TEST_MAX_SEQ_NUM,
        dataset_name=cfg.DATASETS.NAME),
                             batch_size=cfg.TEST.SEQS_PER_BATCH,
                             shuffle=False,
                             num_workers=cfg.DATALOADER.NUM_WORKERS,
                             pin_memory=pin_memory,
                             drop_last=False)

    galleryloader = DataLoader(
        VideoDataset(dataset.gallery,
                     seq_len=cfg.DATASETS.SEQ_LEN,
                     sample=cfg.DATASETS.TEST_SAMPLE_METHOD,
                     transform=transform_test,
                     max_seq_len=cfg.DATASETS.TEST_MAX_SEQ_NUM,
                     dataset_name=cfg.DATASETS.NAME),
        batch_size=cfg.TEST.SEQS_PER_BATCH,
        shuffle=False,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        pin_memory=pin_memory,
        drop_last=False,
    )

    if cfg.MODEL.SYN_BN:
        if use_gpu:
            model = nn.DataParallel(model)
        if cfg.SOLVER.FP_16:
            model = apex.parallel.convert_syncbn_model(model)
        model.cuda()

    start_time = time.time()
    xent = CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids)
    tent = TripletLoss(cfg.SOLVER.MARGIN)

    optimizer = make_optimizer(cfg, model)

    scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS,
                                  cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
                                  cfg.SOLVER.WARMUP_ITERS,
                                  cfg.SOLVER.WARMUP_METHOD)
    # metrics = test(model, queryloader, galleryloader, cfg.TEST.TEMPORAL_POOL_METHOD, use_gpu)
    no_rise = 0
    best_rank1 = 0
    start_epoch = 0
    for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCHS):
        # if no_rise == 10:
        #     break
        scheduler.step()
        print("noriase:", no_rise)
        print("==> Epoch {}/{}".format(epoch + 1, cfg.SOLVER.MAX_EPOCHS))
        print("current lr:", scheduler.get_lr()[0])

        train(model, trainloader, xent, tent, optimizer, use_gpu)
        if cfg.SOLVER.EVAL_PERIOD > 0 and (
            (epoch + 1) % cfg.SOLVER.EVAL_PERIOD == 0 or
            (epoch + 1) == cfg.SOLVER.MAX_EPOCHS):
            print("==> Test")

            metrics = test(model, queryloader, galleryloader,
                           cfg.TEST.TEMPORAL_POOL_METHOD, use_gpu)
            rank1 = metrics[0]
            if rank1 > best_rank1:
                best_rank1 = rank1
                no_rise = 0
            else:
                no_rise += 1
                continue

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            torch.save(
                state_dict,
                osp.join(
                    cfg.OUTPUT_DIR, "rank1_" + str(rank1) + '_checkpoint_ep' +
                    str(epoch + 1) + '.pth'))
            # best_p = osp.join(cfg.OUTPUT_DIR, "rank1_" + str(rank1) + '_checkpoint_ep' + str(epoch + 1) + '.pth')

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))
Beispiel #18
0
def main(config):
    svname = args.name
    if svname is None:
        svname = 'pretrain-multi'
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    with open(os.path.join(save_path, 'config.yaml'), 'w') as f:
        yaml.dump(config, f)

    #### Dataset ####

    def make_dataset(name):
        dataset = make_md([name],
            'batch', split='train', image_size=126, batch_size=256)
        return dataset

    ds_names = ['ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd', \
            'quickdraw', 'fungi', 'vgg_flower']
    datasets = []
    for name in ds_names:
        datasets.append(make_dataset(name))
    iters = []
    for d in datasets:
        iters.append(d.make_one_shot_iterator().get_next())

    to_torch_labels = lambda a: torch.from_numpy(a).long()

    to_pil = transforms.ToPILImage()
    augmentation = transforms.Compose([
        transforms.Resize(146),
        transforms.RandomResizedCrop(128),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])
    ])
    ########

    #### Model and Optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(
            model.parameters(),
            config['optimizer'], **config['optimizer_args'])

    ########
    
    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves_keys = ['tl', 'ta', 'vl', 'va']
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        n_batch = 915547 // 256  # batches per epoch (total training images // batch size of 256)
        with tf.Session() as sess:
            for i_batch in tqdm(range(n_batch)):
                if random.randint(0, 1) == 0:
                    ds_id = 0
                else:
                    ds_id = random.randint(1, len(datasets) - 1)

                next_element = iters[ds_id]
                e, cfr_id = sess.run(next_element)

                data_, label = e[0], to_torch_labels(e[1])
                data_ = ((data_ + 1.0) * 0.5 * 255).astype('uint8')
                data = torch.zeros(256, 3, 128, 128).float()
                for i in range(len(data_)):
                    x = data_[i]
                    x = to_pil(x)
                    x = augmentation(x)
                    data[i] = x

                data = data.cuda()
                label = label.cuda()

                logits = model(data, cfr_id=ds_id)
                loss = F.cross_entropy(logits, label)
                acc = utils.compute_acc(logits, label)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                aves['tl'].add(loss.item())
                aves['ta'].add(acc)

                logits = None; loss = None

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)

        if epoch <= max_epoch:
            epoch_str = str(epoch)
        else:
            epoch_str = 'ex'
        log_str = 'epoch {}, train {:.4f}|{:.4f}'.format(
                epoch_str, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'train': aves['ta']}, epoch)

        if epoch <= max_epoch:
            log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        else:
            log_str += ', {}'.format(t_epoch)
        utils.log(log_str)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,

            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),

            'training': training,
        }
        if epoch <= max_epoch:
            torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))

            if (save_epoch is not None) and epoch % save_epoch == 0:
                torch.save(save_obj, os.path.join(
                    save_path, 'epoch-{}.pth'.format(epoch)))

            if aves['va'] > max_va:  # 'va' is never updated in this script, so this branch is effectively inert
                max_va = aves['va']
                torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))
        else:
            torch.save(save_obj, os.path.join(save_path, 'epoch-ex.pth'))

        writer.flush()
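
The save_obj written each epoch bundles the model spec, its state dict and the optimizer state. A minimal resume sketch under the same conventions (models.make and utils.make_optimizer as used in the snippet; the checkpoint path is only an example):

import torch
# 'models' and 'utils' refer to the same project modules used in the example above
sv = torch.load('./save/pretrain-multi/epoch-last.pth')
model = models.make(sv['model'], **sv['model_args'])
model.load_state_dict(sv['model_sd'])

training = sv['training']
optimizer, lr_scheduler = utils.make_optimizer(
    model.parameters(), training['optimizer'], **training['optimizer_args'])
optimizer.load_state_dict(training['optimizer_sd'])
start_epoch = training['epoch'] + 1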
Beispiel #19
0
tf.summary.image('Y', real_Y)
tf.summary.image('X/generated', fake_Y)
tf.summary.image('X/reconstruction', cyc_X)
tf.summary.image('Y/generated', fake_X)
tf.summary.image('Y/reconstruction', cyc_Y)

merged = tf.summary.merge_all()

model_vars = tf.trainable_variables()

D_X_vars = [var for var in model_vars if 'D_X' in var.name]
G_vars = [var for var in model_vars if 'G' in var.name]
D_Y_vars = [var for var in model_vars if 'D_Y' in var.name]
F_vars = [var for var in model_vars if 'F' in var.name]

G_optimizer = make_optimizer(G_loss, G_vars, learning_rate, beta1, name='Adam_G')
D_Y_optimizer = make_optimizer(D_Y_loss, D_Y_vars, learning_rate, beta1, name='Adam_D_Y')
F_optimizer = make_optimizer(F_loss, F_vars, learning_rate, beta1, name='Adam_F')
D_X_optimizer = make_optimizer(D_X_loss, D_X_vars, learning_rate, beta1, name='Adam_D_X')


saver = tf.train.Saver()

with tf.Session() as sess:
    print('---Create Session---')
    summary_writer = tf.summary.FileWriter(MODEL_SAVE_PATH, sess.graph)
    step = 0
    sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
    sess.run([iterator_X.initializer, iterator_Y.initializer])
    # sess.run()
    while True:
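
The snippet above is cut off before the training-loop body. Its make_optimizer calls build one update op per sub-network from a loss and a variable list; a minimal TF1-style sketch under that assumption (the learning-rate decay of the original CycleGAN code is omitted):

import tensorflow as tf  # TF1.x graph-mode API

def make_optimizer(loss, variables, learning_rate, beta1, name='Adam'):
    # hypothetical helper: a single Adam step that only updates `variables`
    global_step = tf.Variable(0, trainable=False, name=name + '_step')
    optimizer = tf.train.AdamOptimizer(learning_rate, beta1=beta1, name=name)
    return optimizer.minimize(loss, global_step=global_step, var_list=variables)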
Beispiel #20
0
def main(config):
    svname = args.name
    if svname is None:
        svname = 'meta_{}-{}shot'.format(
                config['train_dataset'], config['n_shot'])
        svname += '_' + config['model'] + '-' + config['model_args']['encoder']
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    with open(os.path.join(save_path, 'config.yaml'), 'w') as f:
        yaml.dump(config, f)

    #### Dataset ####

    n_way, n_shot = config['n_way'], config['n_shot']
    n_query = config['n_query']

    if config.get('n_train_way') is not None:
        n_train_way = config['n_train_way']
    else:
        n_train_way = n_way
    if config.get('n_train_shot') is not None:
        n_train_shot = config['n_train_shot']
    else:
        n_train_shot = n_shot
    if config.get('ep_per_batch') is not None:
        ep_per_batch = config['ep_per_batch']
    else:
        ep_per_batch = 1

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    utils.log('train dataset: {} (x{}), {}'.format(
            train_dataset[0][0].shape, len(train_dataset),
            train_dataset.n_classes))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)
    train_sampler = CategoriesSampler(
            train_dataset.label, config['train_batches'],
            n_train_way, n_train_shot + n_query,
            ep_per_batch=ep_per_batch)
    train_loader = DataLoader(train_dataset, batch_sampler=train_sampler,
                              num_workers=8, pin_memory=True)

    # tval
    if config.get('tval_dataset'):
        tval_dataset = datasets.make(config['tval_dataset'],
                                     **config['tval_dataset_args'])
        utils.log('tval dataset: {} (x{}), {}'.format(
                tval_dataset[0][0].shape, len(tval_dataset),
                tval_dataset.n_classes))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(tval_dataset, 'tval_dataset', writer)
        tval_sampler = CategoriesSampler(
                tval_dataset.label, 200,
                n_way, n_shot + n_query,
                ep_per_batch=4)
        tval_loader = DataLoader(tval_dataset, batch_sampler=tval_sampler,
                                 num_workers=8, pin_memory=True)
    else:
        tval_loader = None

    # val
    val_dataset = datasets.make(config['val_dataset'],
                                **config['val_dataset_args'])
    utils.log('val dataset: {} (x{}), {}'.format(
            val_dataset[0][0].shape, len(val_dataset),
            val_dataset.n_classes))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    val_sampler = CategoriesSampler(
            val_dataset.label, 200,
            n_way, n_shot + n_query,
            ep_per_batch=4)
    val_loader = DataLoader(val_dataset, batch_sampler=val_sampler,
                            num_workers=8, pin_memory=True)

    ########

    #### Model and optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

        if config.get('load_encoder'):
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(
            model.parameters(),
            config['optimizer'], **config['optimizer_args'])

    ########
    
    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'tvl', 'tva', 'vl', 'va']
    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model) 
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        np.random.seed(epoch)
        for data, _ in tqdm(train_loader, desc='train', leave=False):
            x_shot, x_query = fs.split_shot_query(
                    data.cuda(), n_train_way, n_train_shot, n_query,
                    ep_per_batch=ep_per_batch)
            label = fs.make_nk_label(n_train_way, n_query,
                    ep_per_batch=ep_per_batch).cuda()

            logits = model(x_shot, x_query).view(-1, n_train_way)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None; loss = None 

        # eval
        model.eval()

        for name, loader, name_l, name_a in [
                ('tval', tval_loader, 'tvl', 'tva'),
                ('val', val_loader, 'vl', 'va')]:

            if (config.get('tval_dataset') is None) and name == 'tval':
                continue

            np.random.seed(0)
            for data, _ in tqdm(loader, desc=name, leave=False):
                x_shot, x_query = fs.split_shot_query(
                        data.cuda(), n_way, n_shot, n_query,
                        ep_per_batch=4)
                label = fs.make_nk_label(n_way, n_query,
                        ep_per_batch=4).cuda()

                with torch.no_grad():
                    logits = model(x_shot, x_query).view(-1, n_way)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)
                
                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

        _sig = int(_[-1])  # '_' is the (unused) label tensor from the last eval batch

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        utils.log('epoch {}, train {:.4f}|{:.4f}, tval {:.4f}|{:.4f}, '
                'val {:.4f}|{:.4f}, {} {}/{} (@{})'.format(
                epoch, aves['tl'], aves['ta'], aves['tvl'], aves['tva'],
                aves['vl'], aves['va'], t_epoch, t_used, t_estimate, _sig))

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'tval': aves['tvl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'tval': aves['tva'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,

            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),

            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()
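
fs.make_nk_label builds the query labels for an n-way episode. A small sketch of the expected behaviour (0..n_way-1, each repeated n_query times, tiled over the episodes in a batch; this is an assumption about the helper, not its actual source):

import torch

def make_nk_label(n_way, n_query, ep_per_batch=1):
    # per episode: [0]*n_query, [1]*n_query, ..., [n_way-1]*n_query
    label = torch.arange(n_way).unsqueeze(1).expand(n_way, n_query).reshape(-1)
    return label.repeat(ep_per_batch)

# make_nk_label(3, 2) -> tensor([0, 0, 1, 1, 2, 2])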
Beispiel #21
0
def main():
    model = get_model()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    loader = data.Data(args, "train").train_loader
    val_loader = data.Data(args, "valid").valid_loader
    rank = torch.Tensor([i for i in range(101)]).to(device)
    best_mae = 10000
    for i in range(args.epochs):
        lr = 0.001 if i < 30 else 0.0001
        optimizer = utils.make_optimizer(args, model, lr)
        model.train()
        print('Learning rate:{}'.format(lr))
        start_time = time.time()
        for j, inputs in enumerate(loader):
            img, label, age = inputs
            img = img.to(device)
            label = label.to(device)
            age = age.to(device)
            optimizer.zero_grad()
            outputs = model(img)
            ages = torch.sum(outputs * rank, dim=1)
            loss1 = loss.kl_loss(outputs, label)
            loss2 = loss.L1_loss(ages, age)
            total_loss = loss1 + loss2
            total_loss.backward()
            optimizer.step()
            current_time = time.time()
            print('[Epoch:{}] \t[batch:{}]\t[loss={:.4f}]'.format(
                i, j, total_loss.item()))
        torch.cuda.empty_cache()
        model.eval()
        count = 0
        error = 0
        total_loss = 0
        with torch.no_grad():
            for inputs in val_loader:
                img, label, age = inputs
                count += len(age)
                img = img.to(device)
                label = label.to(device)
                age = age.to(device)
                outputs = model(img)
                ages = torch.sum(outputs * rank, dim=1)
                loss1 = loss.kl_loss(outputs, label)
                loss2 = loss.L1_loss(ages, age)
                total_loss += loss1 + loss2
                error += torch.sum(abs(ages - age))
        mae = error / count
        if mae < best_mae:
            print(
                "Epoch: {}\tVal loss: {:.5f}\tVal MAE: {:.4f} improved from {:.4f}"
                .format(i, total_loss / count, mae, best_mae))
            best_mae = mae
            torch.save(
                model,
                "checkpoint/epoch{:03d}_{}_{:.5f}_{:.4f}_{}_{}.pth".format(
                    i, args.dataset, total_loss / count, best_mae,
                    datetime.now().strftime("%Y%m%d"), args.model_name))
        else:
            print(
                "Epoch: {}\tVal loss: {:.5f}\tBest Val MAE: {:.4f} not improved, current MAE: {:.4f}"
                .format(i, total_loss / count, best_mae, mae))
        torch.cuda.empty_cache()
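
The line ages = torch.sum(outputs * rank, dim=1) above computes the expected age under the predicted distribution over the 0..100 rank vector. A tiny self-contained check of that reduction:

import torch

rank = torch.arange(101, dtype=torch.float32)  # ages 0..100
outputs = torch.zeros(1, 101)
outputs[0, 24], outputs[0, 26] = 0.5, 0.5      # probability mass at ages 24 and 26
ages = torch.sum(outputs * rank, dim=1)
print(ages)                                    # tensor([25.])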
Beispiel #22
0
    tiger_val_loader = get_val_dataloader(root_path=settings.TRAIN_PATH,
                                          num_workers=args.w,
                                          batch_size=8,
                                          shuffle=args.s)

    loss_function = CrossEntropyLabelSmooth(num_classes=35)

    # optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

    # train_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=settings.MILESTONES, gamma=0.2)   #learning rate decay
    # iter_per_epoch = len(tiger_training_loader)
    # warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warm)

    # original learning-rate policy
    optimizer = make_optimizer(args, net)
    warmup_scheduler = WarmupMultiStepLR(
        optimizer,
        settings.MILESTONES,
        gamma=0.5,  #0.1, 0.5
        warmup_factor=1.0 / 3,
        warmup_iters=0,
        warmup_method="linear",
        last_epoch=-1,
    )

    #cycle lr
    # optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    # warmup_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    #     optimizer,
    #     T_max = 10,
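
The WarmupMultiStepLR used above combines a linear warmup with step decay at the milestones. The same schedule can be sketched with the built-in LambdaLR; the helper below mirrors the arguments of the call above and is an assumption, not the project's actual implementation:

from bisect import bisect_right
import torch

def warmup_multistep_lambda(milestones, gamma, warmup_factor, warmup_iters):
    milestones = sorted(milestones)
    def f(epoch):
        if epoch < warmup_iters:
            # linear warmup from warmup_factor to 1
            alpha = epoch / max(1, warmup_iters)
            return warmup_factor * (1 - alpha) + alpha
        return gamma ** bisect_right(milestones, epoch)
    return f

# usage sketch:
# scheduler = torch.optim.lr_scheduler.LambdaLR(
#     optimizer, warmup_multistep_lambda(settings.MILESTONES, 0.5, 1.0 / 3, 0))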
Beispiel #23
0
'''###########################################'''
'''#### 1. define datasets                 ###'''
'''###########################################'''
dset = Set_Dataset(args)
'''###########################################'''
'''#### 2. define loss functions           ###'''
'''###########################################'''
loss = Set_Loss(args)
'''###########################################'''
'''#### 3. define SR model                 ###'''
'''###########################################'''
SR_model = models.SR_model_selector(args).model_return()
'''############################################'''
'''#### 4. define optimizer & scheduler     ###'''
'''############################################'''
SR_optimizer = utils.make_optimizer(args, SR_model)
SR_lrscheduler = utils.make_lrscheduler(args, SR_optimizer)
'''############################################'''
'''#### 5. do training or testing           ###'''
'''############################################'''
p = phases(args)
if args.phase == 'train':
    timer.tic()
    print('start training... %s' % (datetime.now()))

    for epoch in range(args.num_epoch):
        p.do('train', epoch, SR_model, loss, SR_optimizer, dset.tr_dataloader,
             dset.vl_dataloader, dset.te_dataloader)
        p.do('valid', epoch, SR_model, loss, SR_optimizer, dset.tr_dataloader,
             dset.vl_dataloader, dset.te_dataloader)
Beispiel #24
0
def main(config):
    svname = config.get('sv_name')
    if args.tag is not None:
        svname += '_' + args.tag
    config['sv_name'] = svname
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    utils.log(svname)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))
    with open(os.path.join(save_path, 'config.yaml'), 'w') as f:
        yaml.dump(config, f)

    #### Dataset ####

    n_way, n_shot = config['n_way'], config['n_shot']
    n_query = config['n_query']
    n_pseudo = config['n_pseudo']
    ep_per_batch = config['ep_per_batch']

    if config.get('test_batches') is not None:
        test_batches = config['test_batches']
    else:
        test_batches = config['train_batches']

    for s in ['train', 'val', 'tval']:
        if config.get(f"{s}_dataset_args") is not None:
            config[f"{s}_dataset_args"]['data_dir'] = os.path.join(os.getcwd(), os.pardir, 'data_root')

    # train
    train_dataset = CustomDataset(config['train_dataset'], save_dir=config.get('load_encoder'),
                                  **config['train_dataset_args'])

    if config['train_dataset_args']['split'] == 'helper':
        with open(os.path.join(save_path, 'train_helper_cls.pkl'), 'wb') as f:
            pkl.dump(train_dataset.dataset_classes, f)

    train_sampler = EpisodicSampler(train_dataset, config['train_batches'], n_way, n_shot, n_query,
                                    n_pseudo, episodes_per_batch=ep_per_batch)
    train_loader = DataLoader(train_dataset, batch_sampler=train_sampler,
                                  num_workers=4, pin_memory=True)

    # tval
    if config.get('tval_dataset'):
        tval_dataset = CustomDataset(config['tval_dataset'],
                                     **config['tval_dataset_args'])

        tval_sampler = EpisodicSampler(tval_dataset, test_batches, n_way, n_shot, n_query,
                                       n_pseudo, episodes_per_batch=ep_per_batch)
        tval_loader = DataLoader(tval_dataset, batch_sampler=tval_sampler,
                                 num_workers=4, pin_memory=True)
    else:
        tval_loader = None

    # val
    val_dataset = CustomDataset(config['val_dataset'],
                                **config['val_dataset_args'])
    val_sampler = EpisodicSampler(val_dataset, test_batches, n_way, n_shot, n_query,
                                  n_pseudo, episodes_per_batch=ep_per_batch)
    val_loader = DataLoader(val_dataset, batch_sampler=val_sampler,
                            num_workers=4, pin_memory=True)


    #### Model and optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])
        if config.get('load_encoder'):
            encoder = models.load(torch.load(config['load_encoder'])).encoder
            model.encoder.load_state_dict(encoder.state_dict())
            if config.get('freeze_encoder'):
                for param in model.encoder.parameters():
                    param.requires_grad = False

    if config.get('_parallel'):
        model = nn.DataParallel(model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(
        model.parameters(),
        config['optimizer'], **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    aves_keys = ['tl', 'ta', 'tvl', 'tva', 'vl', 'va']
    trlog = dict()
    for k in aves_keys:
        trlog[k] = []

    for epoch in range(1, max_epoch + 1):
        timer_epoch.s()
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        if config.get('freeze_bn'):
            utils.freeze_bn(model)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        np.random.seed(epoch)

        for data in tqdm(train_loader, desc='train', leave=False):
            x_shot, x_query, x_pseudo = fs.split_shot_query(
                data.cuda(), n_way, n_shot, n_query, n_pseudo,
                ep_per_batch=ep_per_batch)
            label = fs.make_nk_label(n_way, n_query,
                                     ep_per_batch=ep_per_batch).cuda()

            logits = model(x_shot, x_query, x_pseudo)
            logits = logits.view(-1, n_way)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None; loss = None

        # eval
        model.eval()
        for name, loader, name_l, name_a in [
            ('tval', tval_loader, 'tvl', 'tva'),
            ('val', val_loader, 'vl', 'va')]:

            if (config.get('tval_dataset') is None) and name == 'tval':
                continue

            np.random.seed(0)
            for data in tqdm(loader, desc=name, leave=False):
                x_shot, x_query, x_pseudo = fs.split_shot_query(
                    data.cuda(), n_way, n_shot, n_query, n_pseudo,
                    ep_per_batch=ep_per_batch)
                label = fs.make_nk_label(n_way, n_query,
                                         ep_per_batch=ep_per_batch).cuda()

                with torch.no_grad():
                    logits = model(x_shot, x_query, x_pseudo)
                    logits = logits.view(-1, n_way)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves[name_l].add(loss.item())
                aves[name_a].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()
            trlog[k].append(aves[k])

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
        utils.log('epoch {}, train {:.4f}|{:.4f}, tval {:.4f}|{:.4f}, '
                  'val {:.4f}|{:.4f}, {} {}/{}'.format(
            epoch, aves['tl'], aves['ta'], aves['tvl'], aves['tva'],
            aves['vl'], aves['va'], t_epoch, t_used, t_estimate))

        writer.add_scalars('loss', {
            'train': aves['tl'],
            'tval': aves['tvl'],
            'val': aves['vl'],
        }, epoch)
        writer.add_scalars('acc', {
            'train': aves['ta'],
            'tval': aves['tva'],
            'val': aves['va'],
        }, epoch)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,

            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),

            'training': training,
        }
        torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))
        torch.save(trlog, os.path.join(save_path, 'trlog.pth'))

        if (save_epoch is not None) and epoch % save_epoch == 0:
            torch.save(save_obj,
                       os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if aves['va'] > max_va:
            max_va = aves['va']
            torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))

        writer.flush()
Beispiel #25
0
def train(config):
    #### set the save and log path ####
    # svname = args.name
    # if svname is None:
    #     svname = config['train_dataset_type'] + '_' + config['model']
    #     # svname += '_' + config['model_args']['encoder']
    #     # if config['model_args']['classifier'] != 'linear-classifier':
    #     #     svname += '-' + config['model_args']['classifier']
    # if args.tag is not None:
    #     svname += '_' + args.tag
    # save_path = os.path.join('./save', svname)
    save_path = config['save_path']
    utils.set_save_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(config['save_path'], 'tensorboard'))
    with open(os.path.join(config['save_path'], 'classifier_config.yaml'), 'w') as f:
        yaml.dump(config, f)

    #### make datasets ####
    # train
    train_folder = config['dataset_path'] + config['train_dataset_type'] + "/training/frames"
    test_folder = config['dataset_path'] + config['train_dataset_type'] + "/testing/frames"

    # Loading dataset
    train_dataset_args = config['train_dataset_args']
    test_dataset_args = config['test_dataset_args']
    train_dataset = VadDataset(train_folder, transforms.Compose([
        transforms.ToTensor(),
    ]), resize_height=train_dataset_args['h'], resize_width=train_dataset_args['w'], time_step=train_dataset_args['t_length'] - 1)

    test_dataset = VadDataset(test_folder, transforms.Compose([
        transforms.ToTensor(),
    ]), resize_height=test_dataset_args['h'], resize_width=test_dataset_args['w'], time_step=test_dataset_args['t_length'] - 1)

    train_dataloader = DataLoader(train_dataset, batch_size=train_dataset_args['batch_size'],
                                  shuffle=True, num_workers=train_dataset_args['num_workers'], drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=test_dataset_args['batch_size'],
                                 shuffle=False, num_workers=test_dataset_args['num_workers'], drop_last=False)

    # for test---- prepare labels
    labels = scipy.io.loadmat(config['label_path'])
    if config['test_dataset_type'] == 'shanghai':
        labels = np.expand_dims(labels, 0)
    videos = OrderedDict()
    videos_list = sorted(glob.glob(os.path.join(test_folder, '*')))
    labels_list = []
    label_length = 0
    psnr_list = {}
    for video in sorted(videos_list):
        video_name = video.split('/')[-1]
        videos[video_name] = {}
        videos[video_name]['path'] = video
        videos[video_name]['frame'] = glob.glob(os.path.join(video, '*.jpg'))
        videos[video_name]['frame'].sort()
        videos[video_name]['length'] = len(videos[video_name]['frame'])
        labels_list = np.append(labels_list, labels[video_name][0][4:])
        label_length += videos[video_name]['length']
        psnr_list[video_name] = []

    # Model setting
    if config['generator'] == 'cycle_generator_convlstm':
        ngf = 64
        netG = 'resnet_6blocks'
        norm = 'instance'
        no_dropout = False
        init_type = 'normal'
        init_gain = 0.02
        gpu_ids = []
        model = define_G(train_dataset_args['c'], train_dataset_args['c'],
                             ngf, netG, norm, not no_dropout, init_type, init_gain, gpu_ids)
    elif config['generator'] == 'unet':
        # TODO
        model = UNet(n_channels=train_dataset_args['c']*(train_dataset_args['t_length']-1),
                         layer_nums=num_unet_layers, output_channel=train_dataset_args['c'])
    else:
        raise Exception('The generator is not implemented')

    # optimizer setting
    params = list(model.parameters())  # a single parameter list suffices; the original duplicated model.parameters() twice
    optimizer, lr_scheduler = utils.make_optimizer(
        params, config['optimizer'], config['optimizer_args'])

    # set loss, different range with the source version, should change
    lam_int = 1.0 * 2
    lam_gd = 1.0 * 2
    # TODO here we use no flow loss
    # lam_op = 0  # 2.0
    # op_loss = Flow_Loss()
    
    adversarial_loss = Adversarial_Loss()
    # # TODO if use adv
    # lam_adv = 0.05
    # discriminate_loss = Discriminate_Loss()
    alpha = 1
    l_num = 2
    gd_loss = Gradient_Loss(alpha, train_dataset_args['c'])    
    int_loss = Intensity_Loss(l_num)

    # parallel if muti-gpus
    if torch.cuda.is_available():
        model.cuda()
    if config.get('_parallel'):
        model = nn.DataParallel(model)

    # Training
    utils.log('Start train')
    max_frame_AUC, max_roi_AUC = 0,0
    base_channel_num  = train_dataset_args['c'] * (train_dataset_args['t_length'] - 1)
    save_epoch = 5 if config['save_epoch'] is None else config['save_epoch']
    for epoch in range(config['epochs']):

        model.train()
        for j, imgs in enumerate(tqdm(train_dataloader, desc='train', leave=False)):
            imgs = imgs.cuda()
            # input = imgs[:, :-1, ].view(imgs.shape[0], -1, imgs.shape[-2], imgs.shape[-1])
            input = imgs[:, :-1, ]
            target = imgs[:, -1, ]
            outputs = model(input)
            optimizer.zero_grad()

            g_int_loss = int_loss(outputs, target)
            g_gd_loss = gd_loss(outputs, target)
            loss = lam_gd * g_gd_loss + lam_int * g_int_loss

            loss.backward()
            optimizer.step()
        lr_scheduler.step()

        utils.log('----------------------------------------')
        utils.log('Epoch:' + str(epoch + 1))
        utils.log('----------------------------------------')
        utils.log('Loss: Reconstruction {:.6f}'.format(loss.item()))

        # Testing
        utils.log('Evaluation of ' + config['test_dataset_type'])   

        # Save the model
        if epoch % save_epoch == 0 or epoch == config['epochs'] - 1:
            if not os.path.exists(save_path):
                os.makedirs(save_path) 
            if not os.path.exists(os.path.join(save_path, "models")):
                os.makedirs(os.path.join(save_path, "models")) 
            # TODO
            frame_AUC, roi_AUC = evaluate(test_dataloader, model, labels_list, videos, int_loss, config['test_dataset_type'], test_bboxes=config['test_bboxes'],
                frame_height = train_dataset_args['h'], frame_width=train_dataset_args['w'], 
                is_visual=False, mask_labels_path = config['mask_labels_path'], save_path = os.path.join(save_path, "./final"), labels_dict=labels) 
            
            torch.save(model.state_dict(), os.path.join(save_path, 'models/model-epoch-{}.pth'.format(epoch)))
        else:
            frame_AUC, roi_AUC = evaluate(test_dataloader, model, labels_list, videos, int_loss, config['test_dataset_type'], test_bboxes=config['test_bboxes'],
                frame_height = train_dataset_args['h'], frame_width=train_dataset_args['w']) 

        utils.log('The result of ' + config['test_dataset_type'])
        utils.log("AUC: {}%, roi AUC: {}%".format(frame_AUC*100, roi_AUC*100))

        if frame_AUC > max_frame_AUC:
            max_frame_AUC = frame_AUC
            # TODO
            torch.save(model.state_dict(), os.path.join(save_path, 'models/max-frame_auc-model.pth'))
            evaluate(test_dataloader, model, labels_list, videos, int_loss, config['test_dataset_type'], test_bboxes=config['test_bboxes'],
                frame_height = train_dataset_args['h'], frame_width=train_dataset_args['w'], 
                is_visual=True, mask_labels_path = config['mask_labels_path'], save_path = os.path.join(save_path, "./frame_best"), labels_dict=labels) 
        if roi_AUC > max_roi_AUC:
            max_roi_AUC = roi_AUC
            torch.save(model.state_dict(), os.path.join(save_path, 'models/max-roi_auc-model.pth'))
            evaluate(test_dataloader, model, labels_list, videos, int_loss, config['test_dataset_type'], test_bboxes=config['test_bboxes'],
                frame_height = train_dataset_args['h'], frame_width=train_dataset_args['w'], 
                is_visual=True, mask_labels_path = config['mask_labels_path'], save_path = os.path.join(save_path, "./roi_best"), labels_dict=labels) 

        utils.log('----------------------------------------')

    utils.log('Training is finished')
    utils.log('max_frame_AUC: {}, max_roi_AUC: {}'.format(max_frame_AUC, max_roi_AUC))
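
Intensity_Loss and Gradient_Loss above follow the usual future-frame-prediction formulation: a pixel-wise L_p distance plus a distance between image gradients. A hedged re-implementation sketch (the class names mirror the snippet, the bodies are assumptions):

import torch
import torch.nn as nn

class Intensity_Loss(nn.Module):
    # mean L_p distance between predicted and target frames
    def __init__(self, l_num=2):
        super().__init__()
        self.l_num = l_num

    def forward(self, pred, target):
        return torch.mean(torch.abs(pred - target) ** self.l_num)

class Gradient_Loss(nn.Module):
    # distance between horizontal/vertical image gradients
    def __init__(self, alpha=1, channels=3):
        super().__init__()
        self.alpha = alpha  # 'channels' kept only to mirror the call above

    def forward(self, pred, target):
        def grads(x):
            gx = torch.abs(x[..., :, 1:] - x[..., :, :-1])  # horizontal
            gy = torch.abs(x[..., 1:, :] - x[..., :-1, :])  # vertical
            return gx, gy

        pgx, pgy = grads(pred)
        tgx, tgy = grads(target)
        return (torch.mean(torch.abs(pgx - tgx) ** self.alpha) +
                torch.mean(torch.abs(pgy - tgy) ** self.alpha))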
Beispiel #26
0
def main(cfg):
    global best_loss
    best_loss = 100.

    # make dirs
    for dirs in [cfg["MODELS_DIR"], cfg["OUTPUT_DIR"], cfg["LOGS_DIR"]]:
        if not os.path.exists(dirs):
            os.makedirs(dirs)

    # create dataset
    train_ds = RSNAHemorrhageDS3d(cfg, mode="train")
    valid_ds = RSNAHemorrhageDS3d(cfg, mode="valid")
    test_ds = RSNAHemorrhageDS3d(cfg, mode="test")

    # create model
    extra_model_args = {
        "attention": cfg["ATTENTION"],
        "dropout": cfg["DROPOUT"],
        "num_layers": cfg["NUM_LAYERS"],
        "recur_type": cfg["RECUR_TYPE"],
        "num_heads": cfg["NUM_HEADS"],
        "dim_ffw": cfg["DIM_FFW"]
    }
    if cfg["MODEL_NAME"].startswith("tf_efficient"):
        model = GenericEfficientNet3d(cfg["MODEL_NAME"],
                                      input_channels=cfg["NUM_INP_CHAN"],
                                      num_classes=cfg["NUM_CLASSES"],
                                      **extra_model_args)
    elif "res" in cfg["MODEL_NAME"]:
        model = ResNet3d(cfg["MODEL_NAME"],
                         input_channels=cfg["NUM_INP_CHAN"],
                         num_classes=cfg["NUM_CLASSES"],
                         **extra_model_args)
    # print(model)

    # define loss function & optimizer
    class_weight = torch.FloatTensor(cfg["BCE_W"])
    # criterion = nn.BCEWithLogitsLoss(weight=class_weight)
    criterion = nn.BCEWithLogitsLoss(weight=class_weight, reduction='none')
    kd_criterion = KnowledgeDistillationLoss(temperature=cfg["TAU"])
    valid_criterion = nn.BCEWithLogitsLoss(weight=class_weight,
                                           reduction='none')
    optimizer = make_optimizer(cfg, model)

    if cfg["CUDA"]:
        model = model.cuda()
        criterion = criterion.cuda()
        kd_criterion.cuda()
        valid_criterion = valid_criterion.cuda()

    if args.dtype == 'float16':
        if args.opt_level == "O1":
            keep_batchnorm_fp32 = None
        else:
            keep_batchnorm_fp32 = True
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=args.opt_level,
            keep_batchnorm_fp32=keep_batchnorm_fp32)

    start_epoch = 0
    # optionally resume from a checkpoint
    if cfg["RESUME"]:
        if os.path.isfile(cfg["RESUME"]):
            logger.info("=> Loading checkpoint '{}'".format(cfg["RESUME"]))
            checkpoint = torch.load(cfg["RESUME"], "cpu")
            load_state_dict(checkpoint.pop('state_dict'), model)
            if not args.finetune:
                start_epoch = checkpoint['epoch']
                optimizer.load_state_dict(checkpoint.pop('optimizer'))
                best_loss = checkpoint['best_loss']
            logger.info("=> Loaded checkpoint '{}' (epoch {})".format(
                cfg["RESUME"], checkpoint['epoch']))
        else:
            logger.info("=> No checkpoint found at '{}'".format(cfg["RESUME"]))

    if cfg["MULTI_GPU"]:
        model = nn.DataParallel(model)

    # create data loaders & lr scheduler
    train_loader = DataLoader(train_ds,
                              cfg["BATCH_SIZE"],
                              pin_memory=False,
                              shuffle=True,
                              drop_last=False,
                              num_workers=cfg['NUM_WORKERS'])
    valid_loader = DataLoader(valid_ds,
                              pin_memory=False,
                              shuffle=False,
                              drop_last=False,
                              num_workers=cfg['NUM_WORKERS'])
    test_loader = DataLoader(test_ds,
                             pin_memory=False,
                             collate_fn=test_collate_fn,
                             shuffle=False,
                             drop_last=False,
                             num_workers=cfg['NUM_WORKERS'])
    scheduler = WarmupCyclicalLR("cos",
                                 cfg["BASE_LR"],
                                 cfg["EPOCHS"],
                                 iters_per_epoch=len(train_loader),
                                 warmup_epochs=cfg["WARMUP_EPOCHS"])
    logger.info("Using {} lr scheduler\n".format(scheduler.mode))

    if args.eval:
        _, prob = validate(cfg, valid_loader, model, valid_criterion)
        imgids = pd.read_csv(cfg["DATA_DIR"] + "valid_{}_df_fold{}.csv" \
            .format(cfg["SPLIT"], cfg["FOLD"]))["image"]
        save_df = pd.concat([imgids, pd.DataFrame(prob.numpy())], axis=1)
        save_df.columns = [
            "image", "any", "intraparenchymal", "intraventricular",
            "subarachnoid", "subdural", "epidural"
        ]
        save_df.to_csv(os.path.join(cfg["OUTPUT_DIR"],
                                    "val_" + cfg["SESS_NAME"] + '.csv'),
                       index=False)
        return

    if args.eval_test:
        if not os.path.exists(cfg["OUTPUT_DIR"]):
            os.makedirs(cfg["OUTPUT_DIR"])
        submit_fpath = os.path.join(cfg["OUTPUT_DIR"],
                                    "test_" + cfg["SESS_NAME"] + '.csv')
        submit_df = test(cfg, test_loader, model)
        submit_df.to_csv(submit_fpath, index=False)
        return

    for epoch in range(start_epoch, cfg["EPOCHS"]):
        logger.info("Epoch {}\n".format(str(epoch + 1)))
        random.seed(epoch)
        torch.manual_seed(epoch)
        # train for one epoch
        train(cfg, train_loader, model, criterion, kd_criterion, optimizer,
              scheduler, epoch)
        # evaluate
        loss, _ = validate(cfg, valid_loader, model, valid_criterion)
        # remember best loss and save checkpoint
        is_best = loss < best_loss
        best_loss = min(loss, best_loss)
        if cfg["MULTI_GPU"]:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': cfg["MODEL_NAME"],
                    'state_dict': model.module.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                root=cfg['MODELS_DIR'],
                filename=f"{cfg['SESS_NAME']}.pth")
        else:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': cfg["MODEL_NAME"],
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                root=cfg['MODELS_DIR'],
                filename=f"{cfg['SESS_NAME']}.pth")
Beispiel #27
0
def main(config):
    # Environment setup
    save_dir = config['save_dir']
    utils.ensure_path(save_dir)
    with open(osp.join(save_dir, 'config.yaml'), 'w') as f:
        yaml.dump(config, f, sort_keys=False)
    global log, writer
    logger = set_logger(osp.join(save_dir, 'log.txt'))
    log = logger.info
    writer = SummaryWriter(osp.join(save_dir, 'tensorboard'))

    os.environ['WANDB_NAME'] = config['exp_name']
    os.environ['WANDB_DIR'] = config['save_dir']
    if not config.get('wandb_upload', False):
        os.environ['WANDB_MODE'] = 'dryrun'
    t = config['wandb']
    os.environ['WANDB_API_KEY'] = t['api_key']
    wandb.init(project=t['project'], entity=t['entity'], config=config)

    log('logging init done.')
    log(f'wandb id: {wandb.run.id}')

    # Dataset, model and optimizer
    train_dataset = datasets.make((config['train_dataset']))
    test_dataset = datasets.make((config['test_dataset']))

    model = models.make(config['model'], args=None).cuda()
    log(f'model #params: {utils.compute_num_params(model)}')

    n_gpus = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
    if n_gpus > 1:
        model = nn.DataParallel(model)

    optimizer = utils.make_optimizer(model.parameters(), config['optimizer'])

    train_loader = DataLoader(train_dataset, config['batch_size'], shuffle=True,
                              num_workers=8, pin_memory=True)
    test_loader = DataLoader(test_dataset, config['batch_size'],
                             num_workers=8, pin_memory=True)

    # Ready for training
    max_epoch = config['max_epoch']
    n_milestones = config.get('n_milestones', 1)
    milestone_epoch = max_epoch // n_milestones
    min_test_loss = 1e18

    sample_batch_train = sample_data_batch(train_dataset).cuda()
    sample_batch_test = sample_data_batch(test_dataset).cuda()

    epoch_timer = utils.EpochTimer(max_epoch)
    for epoch in range(1, max_epoch + 1):
        log_text = f'epoch {epoch}'

        # Train
        model.train()

        adjust_lr(optimizer, epoch, max_epoch, config)
        log_temp_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        ave_scalars = {k: utils.Averager() for k in ['loss']}

        pbar = tqdm(train_loader, desc='train', leave=False)
        for data in pbar:
            data = data.cuda()
            t = train_step(model, data, data, optimizer)
            for k, v in t.items():
                ave_scalars[k].add(v, len(data))
            pbar.set_description(desc=f"train loss:{t['loss']:.4f}")

        log_text += ', train:'
        for k, v in ave_scalars.items():
            v = v.item()
            log_text += f' {k}={v:.4f}'
            log_temp_scalar('train/' + k, v, epoch)

        # Test
        model.eval()

        ave_scalars = {k: utils.Averager() for k in ['loss']}

        pbar = tqdm(test_loader, desc='test', leave=False)
        for data in pbar:
            data = data.cuda()
            t = eval_step(model, data, data)
            for k, v in t.items():
                ave_scalars[k].add(v, len(data))
            pbar.set_description(desc=f"test loss:{t['loss']:.4f}")

        log_text += ', test:'
        for k, v in ave_scalars.items():
            v = v.item()
            log_text += f' {k}={v:.4f}'
            log_temp_scalar('test/' + k, v, epoch)

        test_loss = ave_scalars['loss'].item()

        if epoch % milestone_epoch == 0:
            with torch.no_grad():
                pred = model(sample_batch_train).clamp(0, 1)
                video_batch = torch.cat([sample_batch_train, pred], dim=0)
                log_temp_videos('train/videos', video_batch, epoch)
                img_batch = video_batch[:, :, 3, :, :]
                log_temp_images('train/images', img_batch, epoch)

                pred = model(sample_batch_test).clamp(0, 1)
                video_batch = torch.cat([sample_batch_test, pred], dim=0)
                log_temp_videos('test/videos', video_batch, epoch)
                img_batch = video_batch[:, :, 3, :, :]
                log_temp_images('test/images', img_batch, epoch)

        # Summary and save
        log_text += ', {} {}/{}'.format(*epoch_timer.step())
        log(log_text)

        model_ = model.module if n_gpus > 1 else model
        model_spec = config['model']
        model_spec['sd'] = model_.state_dict()
        optimizer_spec = config['optimizer']
        optimizer_spec['sd'] = optimizer.state_dict()
        pth_file = {
            'model': model_spec,
            'optimizer': optimizer_spec,
            'epoch': epoch,
        }

        if test_loss < min_test_loss:
            min_test_loss = test_loss
            wandb.run.summary['min_test_loss'] = min_test_loss
            torch.save(pth_file, osp.join(save_dir, 'min-test-loss.pth'))

        torch.save(pth_file, osp.join(save_dir, 'epoch-last.pth'))

        writer.flush()
Beispiel #28
0
    def train(self,
              train_tfrecord,
              loss_type,
              vgg_fmaps=None,
              vgg_weights=None,
              validation_tfrecord=None,
              training_steps=6000,
              batch_size=2,
              initial_lr=10**(-4),
              decay_steps=2000,
              decay_rate=0.5,
              do_online_augmentation=True,
              log_dir='',
              model_dir='models'):
        """
        train the network with the specified loss, logging progress and saving the trained model
        :param train_tfrecord: a tfrecord for training data (see readers.py for details on the expected format)
        :param loss_type: a string that can be
                     'UNET': ['bce', 'bce-topo']
                     'iUNET': ['i-bce', 'i-bce-equal', 'i-bce-topo', 'i-bce-topo-equal']
                     'SHN': ['s-bce', 's-bce-topo']

        :param vgg_fmaps: a list of the names of the vgg feature maps to be used for the perceptual loss
        :param vgg_weights: a list of weights, each controlling the importance of a feature map of the
                            perceptual loss in the total loss
        :param validation_tfrecord: (optional) a tfrecord to keep track of during training
        :param training_steps: total number of training steps
        :param batch_size: batch size for the optimizer
        :param initial_lr: starting learning rate
        :param decay_steps: learning rate decay steps
        :param decay_rate: rate of decay of the lr (see make_optimizer, make_learning_rate_scheduler utils.py)
        :param do_online_augmentation: if True performs online data augmentation using
                                       the default settings (see readers.py)
        :param log_dir: path to dir where the logs will be stored
        :param model_dir: path to dir where the models will be stored
        :return:
        """

        if 'topo' in loss_type and vgg_weights is None and vgg_fmaps is None:
            # default settings for the topological loss function
            vgg_fmaps = [
                'vgg_19/conv1/conv1_2', 'vgg_19/conv2/conv2_2',
                'vgg_19/conv3/conv3_4'
            ]
            vgg_weights = [0.01, 0.001, 0.0001]

        train_examples = count_records_in_tfrecord(train_tfrecord)
        epoch_step = math.floor(train_examples / batch_size)

        # define the loss: sets self.loss and self.loss_summaries
        self._loss_def(loss_type=loss_type,
                       vgg_weights=vgg_weights,
                       vgg_fmaps=vgg_fmaps)

        # generate a tag for logging purposes and saving directory naming
        self._set_tag_and_create_model_dir(vgg_fmaps, vgg_weights, model_dir)

        # using the default optimizer settings
        with tf.name_scope('optimizer'):
            self.train_op, learning_rate_summary, grads_and_vars_summary = make_optimizer(
                self.loss,
                self.variables,
                lr_start=initial_lr,
                lr_scheduler_type='inverse_time',
                decay_steps=decay_steps,
                decay_rate=decay_rate,
                name='Adam_optimizer')

        # summaries
        input_summary = tf.summary.image('input MIP ', self.x, max_outputs=1)
        ground_truth_summary = tf.summary.image('ground truth',
                                                self.y,
                                                max_outputs=1)
        output_summary = tf.summary.image('output', self.output, max_outputs=1)

        train_summary = tf.summary.merge(
            [learning_rate_summary, output_summary] + self.loss_summaries)
        valid_summary = tf.summary.merge(
            [input_summary, output_summary, ground_truth_summary])

        # readers and saver
        train_reader = Reader(train_tfrecord,
                              image_size=416,
                              channels=1,
                              batch_size=batch_size,
                              do_online_augmentation=do_online_augmentation,
                              do_shuffle=True,
                              name='train_reader')
        x_train, y_train = train_reader.feed()

        # if there is a validation set
        validation_step, validation_examples = -1, 0  # just to suppress warnings
        if validation_tfrecord is not None:
            validation_examples, validation_step, best_quality = count_records_in_tfrecord(
                validation_tfrecord), epoch_step, 0
            validation_reader = Reader(validation_tfrecord,
                                       image_size=416,
                                       channels=1,
                                       batch_size=1,
                                       do_shuffle=False,
                                       do_online_augmentation=False,
                                       name='validation_reader')
            x_test, y_test = validation_reader.feed()

        saver = tf.train.Saver(max_to_keep=1, var_list=self.variables)

        with tf.Session() as sess:
            train_writer = tf.summary.FileWriter(
                os.path.join(log_dir, 'logs', self.name, self.tag), sess.graph)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            sess.run(tf.global_variables_initializer())

            # restore vgg19 for the topological loss term
            # requires pretrained vgg_19 to be inside the path
            if 'topo' in loss_type:
                assert os.path.isfile('vgg_19.ckpt'), \
                    'vgg_19.ckpt must be in the path for training with ' \
                    'loss_type=[{}]; please download it from {}'.format(
                        loss_type, self.vgg19_link)

                vgg_restore_list = [
                    v for v in tf.global_variables()
                    if 'vgg_19' in v.name and 'Adam' not in v.name
                ]
                restorer_vgg = tf.train.Saver(var_list=vgg_restore_list)
                restorer_vgg.restore(sess, 'vgg_19.ckpt')

            epoch, costs, losses_train = 0, [], []
            for i in range(training_steps):
                x_value, y_value = sess.run([x_train, y_train])

                # the two augmentation functions below run outside the TensorFlow
                # graph, on the numpy arrays fetched from the reader
                if do_online_augmentation:
                    x_value, y_value = do_deformation_off_graph(
                        x_value, y_value, deform_prob=0.5)
                    x_value = do_eraser_off_graph(x_value,
                                                  eraser_prob=0.5,
                                                  boxes_max=50,
                                                  boxes_min=150)

                train_feed = {self.x: x_value, self.y: y_value}
                loss_batch, summary_train, _ = sess.run(
                    [self.loss, train_summary, self.train_op],
                    feed_dict=train_feed)
                train_writer.add_summary(summary_train, i)

                epoch, losses_train, loss_mean, train_writer = track_mean_value_per_epoch(
                    i,
                    loss_batch,
                    epoch_step,
                    epoch,
                    losses_train,
                    train_writer,
                    tag='loss-per-epoch')
                # validation
                if validation_tfrecord is not None and i > 0 and \
                        (i % validation_step == 0 or i == training_steps - 1):
                    sigmoided_out_list = []
                    y_test_value_list = []
                    for test_step in range(validation_examples):
                        x_test_value, y_test_value = sess.run([x_test, y_test])
                        test_feed = {
                            self.x: x_test_value,
                            self.y: y_test_value
                        }
                        sigmoided_out, summary_val = sess.run(
                            [self.output, valid_summary], feed_dict=test_feed)
                        train_writer.add_summary(summary_val, i + test_step)
                        sigmoided_out_list.append(sigmoided_out)
                        y_test_value_list.append(y_test_value)

                    # run metric (computes means across the validation set)
                    correctness, completeness, quality, _, _, _ = correctness_completeness_quality(
                        y_test_value_list, sigmoided_out_list, threshold=0.5)
                    new_quality = quality
                    if best_quality < new_quality:
                        diff = new_quality - best_quality
                        print('EPOCH:', epoch, 'completeness:', completeness,
                              'correctness:', correctness, 'quality:', quality,
                              'previous quality:', best_quality,
                              'NEW_BEST with difference:', diff)
                        best_quality = new_quality
                        save_path = saver.save(
                            sess,
                            os.path.join(model_dir,
                                         self.name + '_' + str(i) + '.ckpt'))
                        print("Model saved in path: %s" % save_path)
                    else:
                        print('EPOCH:', epoch, 'completeness:', completeness,
                              'correctness:', correctness, 'quality:', quality)
                    summary_metrics = tf.Summary()
                    summary_metrics.value.add(tag='completeness',
                                              simple_value=completeness)
                    summary_metrics.value.add(tag='correctness',
                                              simple_value=correctness)
                    summary_metrics.value.add(tag='quality',
                                              simple_value=quality)
                    train_writer.add_summary(summary_metrics, i)

                if i == (training_steps - 1):
                    save_path = saver.save(
                        sess,
                        os.path.join(self.model_save_dir,
                                     self.name + '_' + str(i) + ".ckpt"))
                    print("Final Model saved in path: %s" % save_path)

            coord.request_stop()
            coord.join(threads)
        train_writer.close()
def main(config):
    svname = args.name
    if svname is None:
        svname = 'classifier_{}'.format(config['train_dataset'])
        svname += '_' + config['model_args']['encoder']
        clsfr = config['model_args']['classifier']
        if clsfr != 'linear-classifier':
            svname += '-' + clsfr
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join('./save', svname)
    utils.ensure_path(save_path)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    #### Dataset ####

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    augmentations = [
        transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomRotation(35),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.RandomRotation(35),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomRotation(35),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        transforms.Compose([
            transforms.RandomRotation(35),
            transforms.RandomResizedCrop(size=(80, 80),
                                         scale=(0.08, 1.0),
                                         ratio=(0.75, 1.3333)),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    ]
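    # config['_a'] picks one of the eight augmentation pipelines above (index 0-7)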
    train_dataset.transform = augmentations[int(config['_a'])]
    print(train_dataset.transform)
    print("_a", config['_a'])
    input("Continue with these augmentations?")

    train_loader = DataLoader(train_dataset,
                              config['batch_size'],
                              shuffle=True,
                              num_workers=0,
                              pin_memory=True)
    utils.log('train dataset: {} (x{}), {}'.format(train_dataset[0][0].shape,
                                                   len(train_dataset),
                                                   train_dataset.n_classes))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)

    # val
    if config.get('val_dataset'):
        eval_val = True
        val_dataset = datasets.make(config['val_dataset'],
                                    **config['val_dataset_args'])
        val_loader = DataLoader(val_dataset,
                                config['batch_size'],
                                num_workers=0,
                                pin_memory=True)
        utils.log('val dataset: {} (x{}), {}'.format(val_dataset[0][0].shape,
                                                     len(val_dataset),
                                                     val_dataset.n_classes))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    else:
        eval_val = False

    # few-shot eval
    if config.get('fs_dataset'):
        ef_epoch = config.get('eval_fs_epoch')
        if ef_epoch is None:
            ef_epoch = 5
        eval_fs = True

        fs_dataset = datasets.make(config['fs_dataset'],
                                   **config['fs_dataset_args'])
        utils.log('fs dataset: {} (x{}), {}'.format(fs_dataset[0][0].shape,
                                                    len(fs_dataset),
                                                    fs_dataset.n_classes))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(fs_dataset, 'fs_dataset', writer)

        n_way = 5
        n_query = 15
        n_shots = [1, 5]
        fs_loaders = []
        for n_shot in n_shots:
            fs_sampler = CategoriesSampler(fs_dataset.label,
                                           200,
                                           n_way,
                                           n_shot + n_query,
                                           ep_per_batch=4)
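            # 200 episodes per evaluation pass; each episode samples n_way classes
            # with n_shot support + n_query query images, batched 4 episodes at a time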
            fs_loader = DataLoader(fs_dataset,
                                   batch_sampler=fs_sampler,
                                   num_workers=0,
                                   pin_memory=True)
            fs_loaders.append(fs_loader)
    else:
        eval_fs = False

    ########

    #### Model and Optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

    if eval_fs:
        fs_model = models.make('meta-baseline', encoder=None)
        fs_model.encoder = model.encoder

    if config.get('_parallel'):
        model = nn.DataParallel(model)
        if eval_fs:
            fs_model = nn.DataParallel(fs_model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    for epoch in range(1, max_epoch + 1 + 1):
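        # the loop runs one extra iteration: when epoch == max_epoch + 1 and
        # config['epoch_ex'] is set, one final epoch is trained on the default
        # (un-augmented) transform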
        if epoch == max_epoch + 1:
            if not config.get('epoch_ex'):
                break
            train_dataset.transform = train_dataset.default_transform
            print(train_dataset.transform)
            train_loader = DataLoader(train_dataset,
                                      config['batch_size'],
                                      shuffle=True,
                                      num_workers=0,
                                      pin_memory=True)

        timer_epoch.s()
        aves_keys = ['tl', 'ta', 'vl', 'va']
        if eval_fs:
            for n_shot in n_shots:
                aves_keys += ['fsa-' + str(n_shot)]
        aves = {k: utils.Averager() for k in aves_keys}

        # train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, label in tqdm(train_loader, desc='train', leave=False):
            # for data, label in train_loader:
            data, label = data.cuda(), label.cuda()
            logits = model(data)
            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # eval
        if eval_val:
            model.eval()
            for data, label in tqdm(val_loader, desc='val', leave=False):
                data, label = data.cuda(), label.cuda()
                with torch.no_grad():
                    logits = model(data)
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves['vl'].add(loss.item())
                aves['va'].add(acc)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            fs_model.eval()
            for i, n_shot in enumerate(n_shots):
                np.random.seed(0)
                for data, _ in tqdm(fs_loaders[i],
                                    desc='fs-' + str(n_shot),
                                    leave=False):
                    x_shot, x_query = fs.split_shot_query(data.cuda(),
                                                          n_way,
                                                          n_shot,
                                                          n_query,
                                                          ep_per_batch=4)
                    label = fs.make_nk_label(n_way, n_query,
                                             ep_per_batch=4).cuda()
                    with torch.no_grad():
                        logits = fs_model(x_shot, x_query).view(-1, n_way)
                        acc = utils.compute_acc(logits, label)
                    aves['fsa-' + str(n_shot)].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)
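        # t_estimate extrapolates the total training time linearly from the time used so far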

        if epoch <= max_epoch:
            epoch_str = str(epoch)
        else:
            epoch_str = 'ex'
        log_str = 'epoch {}, train {:.4f}|{:.4f}'.format(
            epoch_str, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'train': aves['ta']}, epoch)

        if eval_val:
            log_str += ', val {:.4f}|{:.4f}'.format(aves['vl'], aves['va'])
            writer.add_scalars('loss', {'val': aves['vl']}, epoch)
            writer.add_scalars('acc', {'val': aves['va']}, epoch)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            log_str += ', fs'
            for n_shot in n_shots:
                key = 'fsa-' + str(n_shot)
                log_str += ' {}: {:.4f}'.format(n_shot, aves[key])
                writer.add_scalars('acc', {key: aves[key]}, epoch)

        if epoch <= max_epoch:
            log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        else:
            log_str += ', {}'.format(t_epoch)
        utils.log(log_str)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        if epoch <= max_epoch:
            torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))

            if (save_epoch is not None) and epoch % save_epoch == 0:
                torch.save(
                    save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

            if aves['va'] > max_va:
                max_va = aves['va']
                torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))
        else:
            torch.save(save_obj, os.path.join(save_path, 'epoch-ex.pth'))

        writer.flush()
Example #30
0
def main(config):
    svname = args.name
    if svname is None:
        svname = 'moco_{}'.format(config['train_dataset'])
        svname += '_' + config['model_args']['encoder']
        out_dim = config['model_args']['encoder_args']['out_dim']
        svname += '-out_dim' + str(out_dim)
    svname += '-seed' + str(args.seed)
    if args.tag is not None:
        svname += '_' + args.tag
    save_path = os.path.join(args.save_dir, svname)
    utils.ensure_path(save_path, remove=False)
    utils.set_log_path(save_path)
    writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))

    yaml.dump(config, open(os.path.join(save_path, 'config.yaml'), 'w'))

    random_state = np.random.RandomState(args.seed)
    print('seed:', args.seed)

    logger = utils.Logger(file_name=os.path.join(save_path, "log_sdout.txt"),
                          file_mode="a+",
                          should_flush=True)

    #### Dataset ####

    # train
    train_dataset = datasets.make(config['train_dataset'],
                                  **config['train_dataset_args'])
    train_loader = DataLoader(train_dataset,
                              config['batch_size'],
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True,
                              drop_last=True)
    utils.log('train dataset: {} (x{})'.format(train_dataset[0][0][0].shape,
                                               len(train_dataset)))
    if config.get('visualize_datasets'):
        utils.visualize_dataset(train_dataset, 'train_dataset', writer)

    # val
    if config.get('val_dataset'):
        eval_val = True
        val_dataset = datasets.make(config['val_dataset'],
                                    **config['val_dataset_args'])
        val_loader = DataLoader(val_dataset,
                                config['batch_size'],
                                num_workers=8,
                                pin_memory=True,
                                drop_last=True)
        utils.log('val dataset: {} (x{})'.format(val_dataset[0][0][0].shape,
                                                 len(val_dataset)))
        if config.get('visualize_datasets'):
            utils.visualize_dataset(val_dataset, 'val_dataset', writer)
    else:
        eval_val = False

    # few-shot eval
    if config.get('eval_fs'):
        ef_epoch = config.get('eval_fs_epoch')
        if ef_epoch is None:
            ef_epoch = 5
        eval_fs = True
        n_way = 2
        n_query = 1
        n_shot = 6

        if config.get('ep_per_batch') is not None:
            ep_per_batch = config['ep_per_batch']
        else:
            ep_per_batch = 1

        # tvals
        fs_loaders = {}
        tval_name_ntasks_dict = {
            'tval': 2000,
            'tval_ff': 600,
            'tval_bd': 480,
            'tval_hd_comb': 400,
            'tval_hd_novel': 320
        }  # numbers depend on dataset
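        # maps each transfer-validation split name to how many few-shot tasks are
        # evaluated on it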
        for tval_type in tval_name_ntasks_dict.keys():
            if config.get('{}_dataset'.format(tval_type)):
                tval_dataset = datasets.make(
                    config['{}_dataset'.format(tval_type)],
                    **config['{}_dataset_args'.format(tval_type)])
                utils.log('{} dataset: {} (x{})'.format(
                    tval_type, tval_dataset[0][0][0].shape, len(tval_dataset)))
                if config.get('visualize_datasets'):
                    utils.visualize_dataset(tval_dataset, 'tval_ff_dataset',
                                            writer)
                tval_sampler = BongardSampler(
                    tval_dataset.n_tasks,
                    n_batch=tval_name_ntasks_dict[tval_type] // ep_per_batch,
                    ep_per_batch=ep_per_batch,
                    seed=random_state.randint(2**31))
                tval_loader = DataLoader(tval_dataset,
                                         batch_sampler=tval_sampler,
                                         num_workers=8,
                                         pin_memory=True)
                fs_loaders.update({tval_type: tval_loader})
            else:
                fs_loaders.update({tval_type: None})

    else:
        eval_fs = False

    ########

    #### Model and Optimizer ####

    if config.get('load'):
        model_sv = torch.load(config['load'])
        model = models.load(model_sv)
    else:
        model = models.make(config['model'], **config['model_args'])

    if eval_fs:
        fs_model = models.make('meta-baseline', encoder=None)
        fs_model.encoder = model.encoder

    if config.get('_parallel'):
        model = nn.DataParallel(model)
        if eval_fs:
            fs_model = nn.DataParallel(fs_model)

    utils.log('num params: {}'.format(utils.compute_n_params(model)))

    optimizer, lr_scheduler = utils.make_optimizer(model.parameters(),
                                                   config['optimizer'],
                                                   **config['optimizer_args'])

    ########

    max_epoch = config['max_epoch']
    save_epoch = config.get('save_epoch')
    max_va = 0.
    timer_used = utils.Timer()
    timer_epoch = utils.Timer()

    for epoch in range(1, max_epoch + 1 + 1):

        timer_epoch.s()
        aves_keys = ['tl', 'ta', 'vl', 'va', 'tvl', 'tva']
        if eval_fs:
            for k, v in fs_loaders.items():
                if v is not None:
                    aves_keys += ['fsa' + k.split('tval')[-1]]
        aves = {ave_k: utils.Averager() for ave_k in aves_keys}

        # train
        model.train()
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        for data, _ in tqdm(train_loader, desc='train', leave=False):
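            # MoCo-style contrastive step: data[0] and data[1] are two augmented views
            # of the same sample; the model returns logits and the matching target labels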
            logits, label = model(im_q=data[0].cuda(), im_k=data[1].cuda())

            loss = F.cross_entropy(logits, label)
            acc = utils.compute_acc(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            aves['tl'].add(loss.item())
            aves['ta'].add(acc)

            logits = None
            loss = None

        # val
        if eval_val:
            model.eval()
            for data, _ in tqdm(val_loader, desc='val', leave=False):
                with torch.no_grad():
                    logits, label = model(im_q=data[0].cuda(),
                                          im_k=data[1].cuda())
                    loss = F.cross_entropy(logits, label)
                    acc = utils.compute_acc(logits, label)

                aves['vl'].add(loss.item())
                aves['va'].add(acc)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            fs_model.eval()
            for k, v in fs_loaders.items():
                if v is not None:
                    ave_key = 'fsa' + k.split('tval')[-1]
                    np.random.seed(0)
                    for data, _ in tqdm(v, desc=ave_key, leave=False):
                        x_shot, x_query = fs.split_shot_query(
                            data[0].cuda(),
                            n_way,
                            n_shot,
                            n_query,
                            ep_per_batch=ep_per_batch)
                        label_query = fs.make_nk_label(
                            n_way, n_query, ep_per_batch=ep_per_batch).cuda()
                        with torch.no_grad():
                            logits = fs_model(x_shot, x_query).view(-1, n_way)
                            acc = utils.compute_acc(logits, label_query)
                        aves[ave_key].add(acc)

        # post
        if lr_scheduler is not None:
            lr_scheduler.step()

        for k, v in aves.items():
            aves[k] = v.item()

        t_epoch = utils.time_str(timer_epoch.t())
        t_used = utils.time_str(timer_used.t())
        t_estimate = utils.time_str(timer_used.t() / epoch * max_epoch)

        if epoch <= max_epoch:
            epoch_str = str(epoch)
        else:
            epoch_str = 'ex'
        log_str = 'epoch {}, train {:.4f}|{:.4f}'.format(
            epoch_str, aves['tl'], aves['ta'])
        writer.add_scalars('loss', {'train': aves['tl']}, epoch)
        writer.add_scalars('acc', {'train': aves['ta']}, epoch)

        if eval_val:
            log_str += ', val {:.4f}|{:.4f}, tval {:.4f}|{:.4f}'.format(
                aves['vl'], aves['va'], aves['tvl'], aves['tva'])
            writer.add_scalars('loss', {'val': aves['vl']}, epoch)
            writer.add_scalars('loss', {'tval': aves['tvl']}, epoch)
            writer.add_scalars('acc', {'val': aves['va']}, epoch)
            writer.add_scalars('acc', {'tval': aves['tva']}, epoch)

        if eval_fs and (epoch % ef_epoch == 0 or epoch == max_epoch + 1):
            log_str += ', fs'
            for ave_key in aves_keys:
                if 'fsa' in ave_key:
                    log_str += ' {}: {:.4f}'.format(ave_key, aves[ave_key])
                    writer.add_scalars('acc', {ave_key: aves[ave_key]}, epoch)

        if epoch <= max_epoch:
            log_str += ', {} {}/{}'.format(t_epoch, t_used, t_estimate)
        else:
            log_str += ', {}'.format(t_epoch)
        utils.log(log_str)

        if config.get('_parallel'):
            model_ = model.module
        else:
            model_ = model

        training = {
            'epoch': epoch,
            'optimizer': config['optimizer'],
            'optimizer_args': config['optimizer_args'],
            'optimizer_sd': optimizer.state_dict(),
        }
        save_obj = {
            'file': __file__,
            'config': config,
            'model': config['model'],
            'model_args': config['model_args'],
            'model_sd': model_.state_dict(),
            'training': training,
        }
        if epoch <= max_epoch:
            torch.save(save_obj, os.path.join(save_path, 'epoch-last.pth'))

            if (save_epoch is not None) and epoch % save_epoch == 0:
                torch.save(
                    save_obj,
                    os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

            if aves['va'] > max_va:
                max_va = aves['va']
                torch.save(save_obj, os.path.join(save_path, 'max-va.pth'))
        else:
            torch.save(save_obj, os.path.join(save_path, 'epoch-ex.pth'))

        writer.flush()

    print('finished training!')
    logger.close()