Example #1
def run(config_file):
    config = load_config(config_file)
    # adjust the working directory when running on Colab or Kaggle (detected via environment variables)
    if 'COLAB_GPU' in os.environ:
        config.work_dir = '/content/drive/My Drive/kaggle_cloud/' + config.work_dir
    elif 'KAGGLE_WORKING_DIR' in os.environ:
        config.work_dir = '/kaggle/working/' + config.work_dir
    print('working directory:', config.work_dir)

    #save the configuration to the working dir
    if not os.path.exists(config.work_dir):
        os.makedirs(config.work_dir, exist_ok=True)
    save_config(config, config.work_dir + '/config.yml')

    # select the GPUs to use
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

    all_transforms = {}
    all_transforms['train'] = get_transforms(config.transforms.train)
    # the dataset has an explicit validation split, so the test-time transforms are reused for it
    all_transforms['valid'] = get_transforms(config.transforms.test)

    print("before rajat config", config.data.height, config.data.width)
    #fetch the dataloaders we need
    dataloaders = {
        phase: make_loader(data_folder=config.data.train_dir,
                           df_path=config.data.train_df_path,
                           phase=phase,
                           img_size=(config.data.height, config.data.width),
                           batch_size=config.train.batch_size,
                           num_workers=config.num_workers,
                           idx_fold=config.data.params.idx_fold,
                           transforms=all_transforms[phase],
                           num_classes=config.data.num_classes,
                           pseudo_label_path=config.train.pseudo_label_path,
                           debug=config.debug)
        for phase in ['train', 'valid']
    }
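    # note: both loaders are built from the same training dataframe; presumably make_loader
    # selects the train/validation rows internally from phase and idx_fold.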

    # create the segmentation model with a pre-trained encoder
    '''
    reference: constructor parameters of the segmentation_models_pytorch (smp) models
    encoder_name: str = "resnet34",
    encoder_depth: int = 5,
    encoder_weights: str = "imagenet",
    decoder_use_batchnorm: bool = True,
    decoder_channels: List[int] = (256, 128, 64, 32, 16),
    decoder_attention_type: Optional[str] = None,
    in_channels: int = 3,
    classes: int = 1,
    activation: Optional[Union[str, callable]] = None,
    aux_params: Optional[dict] = None,
    '''
    model = getattr(smp, config.model.arch)(
        encoder_name=config.model.encoder,
        encoder_weights=config.model.pretrained,
        classes=config.data.num_classes,
        activation=None,
    )
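    # getattr(smp, config.model.arch) resolves the architecture class by name
    # (e.g. smp.Unet or smp.FPN), so the model family is chosen entirely from the config.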

    #fetch the loss
    criterion = get_loss(config)
    params = [
        {
            'params': model.decoder.parameters(),
            'lr': config.optimizer.params.decoder_lr
        },
        {
            'params': model.encoder.parameters(),
            'lr': config.optimizer.params.encoder_lr
        },
    ]
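    # two parameter groups let the randomly initialised decoder use a different (usually
    # higher) learning rate than the pre-trained encoder; the values come from the config.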
    optimizer = get_optimizer(params, config)
    scheduler = get_scheduler(optimizer, config)
    '''
    reference: arguments of the catalyst SupervisedRunner
    https://github.com/catalyst-team/catalyst/blob/master/catalyst/dl/runner/supervised.py

    model (Model): Torch model object
    device (Device): Torch device
    input_key (str): Key in batch dict mapping for model input
    output_key (str): Key in output dict model output
        will be stored under
    input_target_key (str): Key in batch dict mapping for target
    '''

    runner = SupervisedRunner(model=model, device=get_device())

    # metric callbacks: Dice and IoU

    callbacks = [DiceCallback(), IouCallback()]

    # early stopping after early_stop_patience epochs without improvement
    if config.train.early_stop_patience > 0:
        callbacks.append(
            EarlyStoppingCallback(patience=config.train.early_stop_patience))

    # gradient accumulation: the optimizer only steps and zeroes gradients every accumulation_steps batches
    if config.train.accumulation_size > 0:
        accumulation_steps = config.train.accumulation_size // config.train.batch_size
        callbacks.extend([
            CriterionCallback(),
            OptimizerCallback(accumulation_steps=accumulation_steps)
        ])
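        # accumulation_steps is derived from the configured accumulation_size, so the
        # effective batch size seen by the optimizer is roughly accumulation_size,
        # independent of the per-iteration batch_size.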

    # resume from a checkpoint if one exists
    if os.path.exists(config.work_dir + '/checkpoints/best.pth'):
        callbacks.append(
            CheckpointCallback(resume=config.work_dir +
                               '/checkpoints/last_full.pth'))
    '''
    mixup callback, see https://arxiv.org/pdf/1710.09412.pdf
    cutmix is a weighted combination of cutout and mixup
    '''
    if config.train.mixup:
        callbacks.append(MixupCallback())
    if config.train.cutmix:
        callbacks.append(CutMixCallback())
    '''
    training loop via the catalyst SupervisedRunner
    https://github.com/catalyst-team/catalyst/blob/master/catalyst/dl/runner/supervised.py
    set fp16=True to enable NVIDIA mixed-precision training
    '''
    print(config.work_dir)
    print(config.train.minimize_metric)
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=dataloaders,
        logdir=config.work_dir,
        num_epochs=config.train.num_epochs,
        main_metric=config.train.main_metric,
        minimize_metric=config.train.minimize_metric,
        callbacks=callbacks,
        verbose=True,
        fp16=False,
    )
Example #2

if __name__ == "__main__":
    options = argparser()

    if not os.path.exists(options['save_dir']):
        os.makedirs(options['save_dir'])
        os.makedirs(options['log_dir'])

    best_avg_prec = 0.0
    is_best = False
    model = models.WSODModel(options)
    # model = models.WSODModel_LargerCAM(options)
    # model = models.WSODModel_LargerCAM_SameBranch(options)
    # The following return loss classes
    criterion_cls = get_loss(loss_name='CE')  # Cross-entropy loss
    if options['CAM']:
        criterion_loc = get_loss(
            loss_name='CAMLocalityLoss')  # Group sparsity penalty
    else:
        criterion_loc = get_loss(
            loss_name='LocalityLoss')  # Group sparsity penalty
    criterion_clust = get_loss(loss_name='ClusterLoss')  # MEL + BEL
    #criterion_loc_ent = get_loss(loss_name='LEL')  # Entropy type loss for locality

    model = nn.DataParallel(model).cuda()
    torch.multiprocessing.set_sharing_strategy('file_system')

    # Resume from checkpoint
    if options['resume']:
        if os.path.isfile(options['resume']):
Example #3
def train_multi_task(param_file):
    with open('configs.json') as config_params:
        configs = json.load(config_params)

    with open(param_file) as json_params:
        params = json.load(json_params)

    exp_identifier = []
    for (key, val) in params.items():
        if 'tasks' in key:
            continue
        exp_identifier += ['{}={}'.format(key, val)]

    exp_identifier = '|'.join(exp_identifier)
    params['exp_id'] = exp_identifier

    writer = SummaryWriter(log_dir='runs/{}_{}'.format(
        params['exp_id'],
        datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")))

    train_loader, train_dst, val_loader, val_dst = datasets.get_dataset(
        params, configs)
    loss_fn = losses.get_loss(params)
    metric = metrics.get_metrics(params)

    model = model_selector.get_model(params)
    model_params = []
    for m in model:
        model_params += model[m].parameters()

    if 'RMSprop' in params['optimizer']:
        optimizer = torch.optim.RMSprop(model_params, lr=params['lr'])
    elif 'Adam' in params['optimizer']:
        optimizer = torch.optim.Adam(model_params, lr=params['lr'])
    elif 'SGD' in params['optimizer']:
        optimizer = torch.optim.SGD(model_params,
                                    lr=params['lr'],
                                    momentum=0.9)

    tasks = params['tasks']
    all_tasks = configs[params['dataset']]['all_tasks']
    print('Starting training with parameters \n \t{} \n'.format(str(params)))

    if 'mgda' in params['algorithm']:
        approximate_norm_solution = params['use_approximation']
        if approximate_norm_solution:
            print('Using approximate min-norm solver')
        else:
            print('Using full solver')
    n_iter = 0
    loss_init = {}
    for epoch in tqdm(range(NUM_EPOCHS)):
        start = timer()
        print('Epoch {} Started'.format(epoch))
        if (epoch + 1) % 10 == 0:
            # Every 10 epochs, decay the LR by a factor of 0.85
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.85
            print('Decayed learning rate at iteration {}'.format(n_iter))

        for m in model:
            model[m].train()

        for batch in train_loader:
            n_iter += 1
            # First member is always images
            images = batch[0]
            images = Variable(images.cuda())

            labels = {}
            # Read all targets of all tasks
            for i, t in enumerate(all_tasks):
                if t not in tasks:
                    continue
                labels[t] = batch[i + 1]
                labels[t] = Variable(labels[t].cuda())

            # Scaling the loss functions based on the algorithm choice
            loss_data = {}
            grads = {}
            scale = {}
            mask = None
            masks = {}
            if 'mgda' in params['algorithm']:
                # Will use our MGDA_UB if approximate_norm_solution is True. Otherwise, will use MGDA

                if approximate_norm_solution:
                    optimizer.zero_grad()
                    # First compute representations (z)
                    images_volatile = Variable(images.data, volatile=True)
                    rep, mask = model['rep'](images_volatile, mask)
                    # As an approximate solution we only need gradients for input
                    if isinstance(rep, list):
                        # This is a hack to handle psp-net
                        rep = rep[0]
                        rep_variable = [
                            Variable(rep.data.clone(), requires_grad=True)
                        ]
                        list_rep = True
                    else:
                        rep_variable = Variable(rep.data.clone(),
                                                requires_grad=True)
                        list_rep = False

                    # Compute gradients of each loss function wrt z
                    for t in tasks:
                        optimizer.zero_grad()
                        out_t, masks[t] = model[t](rep_variable, None)
                        loss = loss_fn[t](out_t, labels[t])
                        loss_data[t] = loss.data[0]
                        loss.backward()
                        grads[t] = []
                        if list_rep:
                            grads[t].append(
                                Variable(rep_variable[0].grad.data.clone(),
                                         requires_grad=False))
                            rep_variable[0].grad.data.zero_()
                        else:
                            grads[t].append(
                                Variable(rep_variable.grad.data.clone(),
                                         requires_grad=False))
                            rep_variable.grad.data.zero_()
                else:
                    # This is MGDA
                    for t in tasks:
                        # Compute gradients of each loss function wrt parameters
                        optimizer.zero_grad()
                        rep, mask = model['rep'](images, mask)
                        out_t, masks[t] = model[t](rep, None)
                        loss = loss_fn[t](out_t, labels[t])
                        loss_data[t] = loss.data[0]
                        loss.backward()
                        grads[t] = []
                        for param in model['rep'].parameters():
                            if param.grad is not None:
                                grads[t].append(
                                    Variable(param.grad.data.clone(),
                                             requires_grad=False))

                # Normalize all gradients, this is optional and not included in the paper.
                gn = gradient_normalizers(grads, loss_data,
                                          params['normalization_type'])
                for t in tasks:
                    for gr_i in range(len(grads[t])):
                        grads[t][gr_i] = grads[t][gr_i] / gn[t]

                # Frank-Wolfe iteration to compute scales.
                sol, min_norm = MinNormSolver.find_min_norm_element(
                    [grads[t] for t in tasks])
                for i, t in enumerate(tasks):
                    scale[t] = float(sol[i])
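                # the resulting scale[t] are the convex-combination weights of the
                # min-norm point over the task gradients (the MGDA solution); tasks whose
                # gradients conflict most end up down-weighted in the shared update.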
            else:
                for t in tasks:
                    masks[t] = None
                    scale[t] = float(params['scales'][t])

            # Scaled back-propagation
            optimizer.zero_grad()
            rep, _ = model['rep'](images, mask)
            for i, t in enumerate(tasks):
                out_t, _ = model[t](rep, masks[t])
                loss_t = loss_fn[t](out_t, labels[t])
                loss_data[t] = loss_t.data[0]
                if i > 0:
                    loss = loss + scale[t] * loss_t
                else:
                    loss = scale[t] * loss_t
            loss.backward()
            optimizer.step()

            writer.add_scalar('training_loss', loss.data[0], n_iter)
            for t in tasks:
                writer.add_scalar('training_loss_{}'.format(t), loss_data[t],
                                  n_iter)

        for m in model:
            model[m].eval()

        tot_loss = {}
        tot_loss['all'] = 0.0
        met = {}
        for t in tasks:
            tot_loss[t] = 0.0
            met[t] = 0.0

        num_val_batches = 0
        for batch_val in val_loader:
            val_images = Variable(batch_val[0].cuda(), volatile=True)
            labels_val = {}

            for i, t in enumerate(all_tasks):
                if t not in tasks:
                    continue
                labels_val[t] = batch_val[i + 1]
                labels_val[t] = Variable(labels_val[t].cuda(), volatile=True)

            val_rep, _ = model['rep'](val_images, None)
            for t in tasks:
                out_t_val, _ = model[t](val_rep, None)
                loss_t = loss_fn[t](out_t_val, labels_val[t])
                tot_loss['all'] += loss_t.data[0]
                tot_loss[t] += loss_t.data[0]
                metric[t].update(out_t_val, labels_val[t])
            num_val_batches += 1

        for t in tasks:
            writer.add_scalar('validation_loss_{}'.format(t),
                              tot_loss[t] / num_val_batches, n_iter)
            metric_results = metric[t].get_result()
            for metric_key in metric_results:
                writer.add_scalar('metric_{}_{}'.format(metric_key, t),
                                  metric_results[metric_key], n_iter)
            metric[t].reset()
        writer.add_scalar('validation_loss', tot_loss['all'] / len(val_dst),
                          n_iter)

        if epoch % 3 == 0:
            # Save a checkpoint every 3 epochs
            state = {
                'epoch': epoch + 1,
                'model_rep': model['rep'].state_dict(),
                'optimizer_state': optimizer.state_dict()
            }
            for t in tasks:
                key_name = 'model_{}'.format(t)
                state[key_name] = model[t].state_dict()

            torch.save(
                state,
                "saved_models/{}_{}_model.pkl".format(params['exp_id'],
                                                      epoch + 1))

        end = timer()
        print('Epoch ended in {}s'.format(end - start))
Example #4
def train_sim(epoch_num=10,
              optim_type='ACGD',
              startPoint=None,
              start_n=0,
              z_dim=128,
              batchsize=64,
              l2_penalty=0.0,
              momentum=0.0,
              log=False,
              loss_name='WGAN',
              model_name='dc',
              model_config=None,
              data_path='None',
              show_iter=100,
              logdir='test',
              dataname='CIFAR10',
              device='cpu',
              gpu_num=1):
    lr_d = 1e-4
    lr_g = 1e-4
    dataset = get_data(dataname=dataname, path=data_path)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D, G = get_model(model_name=model_name, z_dim=z_dim, configs=model_config)
    D.apply(weights_init_d).to(device)
    G.apply(weights_init_g).to(device)

    optim_d = RMSprop(D.parameters(), lr=lr_d)
    optim_g = RMSprop(G.parameters(), lr=lr_g)

    if startPoint is not None:
        chk = torch.load(startPoint)
        D.load_state_dict(chk['D'])
        G.load_state_dict(chk['G'])
        optim_d.load_state_dict(chk['d_optim'])
        optim_g.load_state_dict(chk['g_optim'])
        print('Start from %s' % startPoint)
    if gpu_num > 1:
        D = nn.DataParallel(D, list(range(gpu_num)))
        G = nn.DataParallel(G, list(range(gpu_num)))
    timer = time.time()
    count = 0
    if 'DCGAN' in model_name:
        fixed_noise = torch.randn((64, z_dim, 1, 1), device=device)
    else:
        fixed_noise = torch.randn((64, z_dim), device=device)
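    # the same fixed latent batch is reused for every saved sample grid so that
    # generated images are comparable across iterations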
    for e in range(epoch_num):
        print('======Epoch: %d / %d======' % (e, epoch_num))
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            if 'DCGAN' in model_name:
                z = torch.randn((d_real.shape[0], z_dim, 1, 1), device=device)
            else:
                z = torch.randn((d_real.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            loss = get_loss(name=loss_name,
                            g_loss=False,
                            d_real=d_real,
                            d_fake=d_fake,
                            l2_weight=l2_penalty,
                            D=D)
            D.zero_grad()
            G.zero_grad()
            loss.backward()
            optim_d.step()
            optim_g.step()

            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter :%d , Loss: %.5f, time: %.3fs' %
                      (count, loss.item(), time_cost))
                timer = time.time()
                with torch.no_grad():
                    fake_img = G(fixed_noise).detach()
                    path = 'figs/%s_%s/' % (dataname, logdir)
                    if not os.path.exists(path):
                        os.makedirs(path)
                    vutils.save_image(fake_img,
                                      path + 'iter_%d.png' % (count + start_n),
                                      normalize=True)
                save_checkpoint(
                    path=logdir,
                    name='%s-%s%.3f_%d.pth' %
                    (optim_type, model_name, lr_g, count + start_n),
                    D=D,
                    G=G,
                    optimizer=optim_d,
                    g_optimizer=optim_g)
            if wandb and log:
                wandb.log({
                    'Real score': d_real.mean().item(),
                    'Fake score': d_fake.mean().item(),
                    'Loss': loss.item()
                })
            count += 1
Example #5
def run(config_file):
    config = load_config(config_file)

    if not os.path.exists(config.work_dir):
        os.makedirs(config.work_dir, exist_ok=True)
    save_config(config, config.work_dir + '/config.yml')

    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    all_transforms = {}
    all_transforms['train'] = get_transforms(config.transforms.train)
    all_transforms['valid'] = get_transforms(config.transforms.test)

    dataloaders = {
        phase: make_loader(data_folder=config.data.train_dir,
                           df_path=config.data.train_df_path,
                           phase=phase,
                           batch_size=config.train.batch_size,
                           num_workers=config.num_workers,
                           idx_fold=config.data.params.idx_fold,
                           transforms=all_transforms[phase],
                           num_classes=config.data.num_classes,
                           pseudo_label_path=config.train.pseudo_label_path,
                           debug=config.debug)
        for phase in ['train', 'valid']
    }

    # create the segmentation model with a pre-trained encoder
    model = getattr(smp, config.model.arch)(
        encoder_name=config.model.encoder,
        encoder_weights=config.model.pretrained,
        classes=config.data.num_classes,
        activation=None,
    )

    # train setting
    criterion = get_loss(config)
    params = [
        {
            'params': model.decoder.parameters(),
            'lr': config.optimizer.params.decoder_lr
        },
        {
            'params': model.encoder.parameters(),
            'lr': config.optimizer.params.encoder_lr
        },
    ]
    optimizer = get_optimizer(params, config)
    scheduler = get_scheduler(optimizer, config)

    # model runner
    runner = SupervisedRunner(model=model)

    callbacks = [DiceCallback(), IouCallback()]

    # to resume from check points if exists
    if os.path.exists(config.work_dir + '/checkpoints/best.pth'):
        callbacks.append(
            CheckpointCallback(resume=config.work_dir +
                               '/checkpoints/best_full.pth'))

    if config.train.mixup:
        callbacks.append(MixupCallback())

    if config.train.cutmix:
        callbacks.append(CutMixCallback())

    # model training
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=dataloaders,
        logdir=config.work_dir,
        num_epochs=config.train.num_epochs,
        callbacks=callbacks,
        verbose=True,
        fp16=True,
    )
Example #6
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    cfg = get_config(args.config)

    try:
        world_size = int(os.environ['WORLD_SIZE'])
        rank = int(os.environ['RANK'])
        dist.init_process_group('nccl')
    except KeyError:
        world_size = 1
        rank = 0
        dist.init_process_group(backend='nccl',
                                init_method="tcp://127.0.0.1:12584",
                                rank=rank,
                                world_size=world_size)

    local_rank = args.local_rank
    torch.cuda.set_device(local_rank)
    os.makedirs(cfg.output, exist_ok=True)
    init_logging(rank, cfg.output)

    if cfg.rec == "synthetic":
        train_set = SyntheticDataset(local_rank=local_rank)
    else:
        train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_set, shuffle=True)
    train_loader = DataLoaderX(local_rank=local_rank,
                               dataset=train_set,
                               batch_size=cfg.batch_size,
                               sampler=train_sampler,
                               num_workers=2,
                               pin_memory=True,
                               drop_last=True)
    backbone = get_model(cfg.network,
                         dropout=0.0,
                         fp16=cfg.fp16,
                         num_features=cfg.embedding_size).to(local_rank)
    summary(backbone, input_size=(3, 112, 112))
    exit()
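    # NOTE: summary() followed by exit() stops the script right after printing the
    # architecture; remove the exit() call to actually run the training loop below.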

    if cfg.resume:
        try:
            backbone_pth = os.path.join(cfg.output, "backbone.pth")
            backbone.load_state_dict(
                torch.load(backbone_pth,
                           map_location=torch.device(local_rank)))
            if rank == 0:
                logging.info("backbone resume successfully!")
        except (FileNotFoundError, KeyError, IndexError, RuntimeError):
            if rank == 0:
                logging.info("resume fail, backbone init successfully!")

    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[local_rank])
    backbone.train()
    if cfg.loss == 'magface':
        margin_softmax = losses.get_loss(cfg.loss, lambda_g=cfg.lambda_g)
    elif cfg.loss == 'mag_cosface':
        margin_softmax = losses.get_loss(cfg.loss)
    else:
        margin_softmax = losses.get_loss(cfg.loss,
                                         s=cfg.s,
                                         m1=cfg.m1,
                                         m2=cfg.m2,
                                         m3=cfg.m3)
    module_partial_fc = PartialFC(rank=rank,
                                  local_rank=local_rank,
                                  world_size=world_size,
                                  resume=cfg.resume,
                                  batch_size=cfg.batch_size,
                                  margin_softmax=margin_softmax,
                                  num_classes=cfg.num_classes,
                                  sample_rate=cfg.sample_rate,
                                  embedding_size=cfg.embedding_size,
                                  prefix=cfg.output)

    opt_backbone = torch.optim.SGD(params=[{
        'params': backbone.parameters()
    }],
                                   lr=cfg.lr / 512 * cfg.batch_size *
                                   world_size,
                                   momentum=0.9,
                                   weight_decay=cfg.weight_decay)
    opt_pfc = torch.optim.SGD(params=[{
        'params': module_partial_fc.parameters()
    }],
                              lr=cfg.lr / 512 * cfg.batch_size * world_size,
                              momentum=0.9,
                              weight_decay=cfg.weight_decay)

    num_image = len(train_set)
    total_batch_size = cfg.batch_size * world_size
    cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch
    cfg.total_step = num_image // total_batch_size * cfg.num_epoch

    def lr_step_func(current_step):
        cfg.decay_step = [
            x * num_image // total_batch_size for x in cfg.decay_epoch
        ]
        if current_step < cfg.warmup_step:
            return current_step / cfg.warmup_step
        else:
            return 0.1**len([m for m in cfg.decay_step if m <= current_step])
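    # LambdaLR multiplies the base lr by this factor: linear warmup over the first
    # warmup_step iterations, then a 10x decay at every milestone in cfg.decay_step.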

    scheduler_backbone = torch.optim.lr_scheduler.LambdaLR(
        optimizer=opt_backbone, lr_lambda=lr_step_func)
    scheduler_pfc = torch.optim.lr_scheduler.LambdaLR(optimizer=opt_pfc,
                                                      lr_lambda=lr_step_func)

    for key, value in cfg.items():
        num_space = 25 - len(key)
        logging.info(": " + key + " " * num_space + str(value))

    val_target = cfg.val_targets
    callback_verification = CallBackVerification(2000, rank, val_target,
                                                 cfg.rec)
    callback_logging = CallBackLogging(50, rank, cfg.total_step,
                                       cfg.batch_size, world_size, None)
    callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)

    loss = AverageMeter()
    start_epoch = 0
    global_step = 0
    grad_amp = MaxClipGradScaler(
        cfg.batch_size, 128 *
        cfg.batch_size, growth_interval=100) if cfg.fp16 else None
    for epoch in range(start_epoch, cfg.num_epoch):
        train_sampler.set_epoch(epoch)
        for step, (img, label) in enumerate(train_loader):
            global_step += 1
            x = backbone(img)
            features = F.normalize(x)
            x_grad, loss_v = module_partial_fc.forward_backward(
                label, features, opt_pfc, x)
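            # x_grad is the gradient of the margin-softmax loss w.r.t. the normalized
            # features, computed inside PartialFC (which keeps only a sampled shard of the
            # class centres per rank); features.backward(x_grad) below pushes it through
            # the backbone.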
            if cfg.fp16:
                features.backward(grad_amp.scale(x_grad))
                grad_amp.unscale_(opt_backbone)
                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
                grad_amp.step(opt_backbone)
                grad_amp.update()
            else:
                features.backward(x_grad)
                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
                opt_backbone.step()

            opt_pfc.step()
            module_partial_fc.update()
            opt_backbone.zero_grad()
            opt_pfc.zero_grad()
            loss.update(loss_v, 1)
            callback_logging(global_step, loss, epoch, cfg.fp16,
                             scheduler_backbone.get_last_lr()[0], grad_amp)
            callback_verification(global_step, backbone)
            scheduler_backbone.step()
            scheduler_pfc.step()
        callback_checkpoint(global_step, backbone, module_partial_fc)

    callback_verification('last', backbone)
    dist.destroy_process_group()
Example #7
def train_mnist(epoch_num=10,
                show_iter=100,
                logdir='test',
                model_weight=None,
                load_d=False,
                load_g=False,
                compare_path=None,
                info_time=100,
                run_select=None,
                dataname='CIFAR10',
                data_path='None',
                device='cpu'):
    lr_d = 0.01
    lr_g = 0.01
    batchsize = 128
    z_dim = 96
    print('MNIST, discriminator lr: %.3f, generator lr: %.3f' % (lr_d, lr_g))
    dataset = get_data(dataname=dataname, path=data_path)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D = dc_D().to(device)
    G = dc_G(z_dim=z_dim).to(device)
    D.apply(weights_init_d)
    G.apply(weights_init_g)
    if model_weight is not None:
        chk = torch.load(model_weight)
        if load_d:
            D.load_state_dict(chk['D'])
            print('Load D from %s' % model_weight)
        if load_g:
            G.load_state_dict(chk['G'])
            print('Load G from %s' % model_weight)
    if compare_path is not None:
        discriminator = dc_D().to(device)
        model_weight = torch.load(compare_path)
        discriminator.load_state_dict(model_weight['D'])
        model_vec = torch.cat(
            [p.contiguous().view(-1) for p in discriminator.parameters()])
        print('Load discriminator from %s' % compare_path)
    if run_select is not None:
        fixed_data = torch.load(run_select)
        real_set = fixed_data['real_set']
        fake_set = fixed_data['fake_set']
        real_d = fixed_data['real_d']
        fake_d = fixed_data['fake_d']
        fixed_vec = fixed_data['pred_vec']
        print('load fixed data set')

    d_optimizer = SGD(D.parameters(), lr=lr_d)
    g_optimizer = SGD(G.parameters(), lr=lr_g)
    timer = time.time()
    count = 0
    fixed_noise = torch.randn((64, z_dim), device=device)
    for e in range(epoch_num):
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            z = torch.randn((d_real.shape[0], z_dim), device=device)
            fake_x = G(z)
            fake_x_c = fake_x.clone().detach()
            # update generator
            d_fake = D(fake_x)

            # writer.add_scalars('Discriminator output', {'Generated image': d_fake.mean().item(),
            #                                             'Real image': d_real.mean().item()},
            #                    global_step=count)
            G_loss = get_loss(name='JSD', g_loss=True, d_fake=d_fake)
            g_optimizer.zero_grad()
            G_loss.backward()
            g_optimizer.step()
            gg = torch.norm(torch.cat(
                [p.grad.contiguous().view(-1) for p in G.parameters()]),
                            p=2)

            d_fake_c = D(fake_x_c)
            D_loss = get_loss(name='JSD',
                              g_loss=False,
                              d_real=d_real,
                              d_fake=d_fake_c)
            if compare_path is not None and count % info_time == 0:
                diff = get_diff(net=D, model_vec=model_vec)
                # writer.add_scalar('Distance from checkpoint', diff.item(), global_step=count)
                if run_select is not None:
                    with torch.no_grad():
                        d_real_set = D(real_set)
                        d_fake_set = D(fake_set)
                        diff_real = torch.norm(d_real_set - real_d, p=2)
                        diff_fake = torch.norm(d_fake_set - fake_d, p=2)
                        d_vec = torch.cat([d_real_set, d_fake_set])
                        diff = torch.norm(d_vec.sub_(fixed_vec), p=2)
                        # writer.add_scalars('L2 norm of pred difference',
                        #                    {'Total': diff.item(),
                        #                     'real set': diff_real.item(),
                        #                     'fake set': diff_fake.item()},
                        #                    global_step=count)
            d_optimizer.zero_grad()
            D_loss.backward()
            d_optimizer.step()
            gd = torch.norm(torch.cat(
                [p.grad.contiguous().view(-1) for p in D.parameters()]),
                            p=2)
            # writer.add_scalars('Loss', {'D_loss': D_loss.item(),
            #                             'G_loss': G_loss.item()}, global_step=count)
            # writer.add_scalars('Grad', {'D grad': gd.item(),
            #                             'G grad': gg.item()}, global_step=count)
            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter :%d , D_loss: %.5f, G_loss: %.5f, time: %.3fs' %
                      (count, D_loss.item(), G_loss.item(), time_cost))
                timer = time.time()
                with torch.no_grad():
                    fake_img = G(fixed_noise).detach()
                    path = 'figs/%s/' % logdir
                    if not os.path.exists(path):
                        os.makedirs(path)
                    vutils.save_image(fake_img,
                                      path + 'iter_%d.png' % count,
                                      normalize=True)
                save_checkpoint(path=logdir,
                                name='SGD-%.3f_%d.pth' % (lr_d, count),
                                D=D,
                                G=G)
            count += 1
Example #8
def train_cgd(epoch_num=10,
              milestone=None,
              optim_type='ACGD',
              startPoint=None,
              start_n=0,
              z_dim=128,
              batchsize=64,
              tols={
                  'tol': 1e-10,
                  'atol': 1e-16
              },
              l2_penalty=0.0,
              momentum=0.0,
              loss_name='WGAN',
              model_name='dc',
              model_config=None,
              data_path='None',
              show_iter=100,
              logdir='test',
              dataname='CIFAR10',
              device='cpu',
              gpu_num=1,
              collect_info=False):
    lr_d = 0.01
    lr_g = 0.01
    dataset = get_data(dataname=dataname, path=data_path)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D, G = get_model(model_name=model_name, z_dim=z_dim, configs=model_config)
    D.apply(weights_init_d).to(device)
    G.apply(weights_init_g).to(device)
    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    writer = SummaryWriter(log_dir='logs/%s/%s_%.3f' %
                           (logdir, current_time, lr_d))
    if optim_type == 'BCGD':
        optimizer = BCGD(max_params=G.parameters(),
                         min_params=D.parameters(),
                         lr_max=lr_g,
                         lr_min=lr_d,
                         momentum=momentum,
                         tol=tols['tol'],
                         atol=tols['atol'],
                         device=device)
        scheduler = lr_scheduler(optimizer=optimizer, milestone=milestone)
    elif optim_type == 'ICR':
        optimizer = ICR(max_params=G.parameters(),
                        min_params=D.parameters(),
                        lr=lr_d,
                        alpha=1.0,
                        device=device)
        scheduler = icrScheduler(optimizer, milestone)
    elif optim_type == 'ACGD':
        optimizer = ACGD(max_params=G.parameters(),
                         min_params=D.parameters(),
                         lr_max=lr_g,
                         lr_min=lr_d,
                         tol=tols['tol'],
                         atol=tols['atol'],
                         device=device,
                         solver='cg')
        scheduler = lr_scheduler(optimizer=optimizer, milestone=milestone)
    if startPoint is not None:
        chk = torch.load(startPoint)
        D.load_state_dict(chk['D'])
        G.load_state_dict(chk['G'])
        optimizer.load_state_dict(chk['optim'])
        print('Start from %s' % startPoint)
    if gpu_num > 1:
        D = nn.DataParallel(D, list(range(gpu_num)))
        G = nn.DataParallel(G, list(range(gpu_num)))
    timer = time.time()
    count = 0
    if 'DCGAN' in model_name:
        fixed_noise = torch.randn((64, z_dim, 1, 1), device=device)
    else:
        fixed_noise = torch.randn((64, z_dim), device=device)
    for e in range(epoch_num):
        scheduler.step(epoch=e)
        print('======Epoch: %d / %d======' % (e, epoch_num))
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            if 'DCGAN' in model_name:
                z = torch.randn((d_real.shape[0], z_dim, 1, 1), device=device)
            else:
                z = torch.randn((d_real.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            loss = get_loss(name=loss_name,
                            g_loss=False,
                            d_real=d_real,
                            d_fake=d_fake,
                            l2_weight=l2_penalty,
                            D=D)
            optimizer.zero_grad()
            optimizer.step(loss)
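            # the CGD-family optimizers (BCGD / ICR / ACGD) update both networks from this
            # single scalar loss: max_params (the generator) move to increase it while
            # min_params (the discriminator) move to decrease it, so no separate generator
            # loss or alternating update is needed.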

            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter :%d , Loss: %.5f, time: %.3fs' %
                      (count, loss.item(), time_cost))
                timer = time.time()
                with torch.no_grad():
                    fake_img = G(fixed_noise).detach()
                    path = 'figs/%s_%s/' % (dataname, logdir)
                    if not os.path.exists(path):
                        os.makedirs(path)
                    vutils.save_image(fake_img,
                                      path + 'iter_%d.png' % (count + start_n),
                                      normalize=True)
                save_checkpoint(
                    path=logdir,
                    name='%s-%s%.3f_%d.pth' %
                    (optim_type, model_name, lr_g, count + start_n),
                    D=D,
                    G=G,
                    optimizer=optimizer)
            writer.add_scalars('Discriminator output', {
                'Generated image': d_fake.mean().item(),
                'Real image': d_real.mean().item()
            },
                               global_step=count)
            writer.add_scalar('Loss', loss.item(), global_step=count)
            if collect_info:
                cgd_info = optimizer.get_info()
                writer.add_scalar('Conjugate Gradient/iter num',
                                  cgd_info['iter_num'],
                                  global_step=count)
                writer.add_scalar('Conjugate Gradient/running time',
                                  cgd_info['time'],
                                  global_step=count)
                writer.add_scalars('Delta', {
                    'D gradient': cgd_info['grad_y'],
                    'G gradient': cgd_info['grad_x'],
                    'D hvp': cgd_info['hvp_y'],
                    'G hvp': cgd_info['hvp_x'],
                    'D cg': cgd_info['cg_y'],
                    'G cg': cgd_info['cg_x']
                },
                                   global_step=count)
            count += 1
    writer.close()
Example #9
def train(opt, netG, netD, optim_G, optim_D):
    tensor = torch.cuda.FloatTensor
    # lossD_list = []
    # lossG_list = []

    train = ReadConcat(opt)
    trainset = DataLoader(train, batch_size=opt.batchSize, shuffle=True)
    save_img_path = os.path.join('./result', 'train')
    check_folder(save_img_path)

    for e in range(opt.epoch, opt.niter + opt.niter_decay + 1):
        for i, data in enumerate(trainset):
            # set input
            data_A = data['A']  # blur
            data_B = data['B']  #sharp
            # plt.imshow(image_recovery(data_A.squeeze().numpy()))
            # plt.pause(0)
            # print(data_A.shape)
            # print(data_B.shape)

            if torch.cuda.is_available():
                data_A = data_A.cuda(opt.gpu)
                data_B = data_B.cuda(opt.gpu)
            # forward
            realA = Variable(data_A)
            fakeB = netG(realA)
            realB = Variable(data_B)

            # optimize_parameters
            # optimizer netD
            set_requires_grad([netD], True)
            for iter_d in range(1):
                optim_D.zero_grad()
                loss_D, _ = get_loss(tensor, netD, realA, fakeB, realB)
                loss_D.backward()
                optim_D.step()

            # optimizer netG
            set_requires_grad([netD], False)
            optim_G.zero_grad()
            _, loss_G = get_loss(tensor, netD, realA, fakeB, realB)
            loss_G.backward()
            optim_G.step()
            if i % 50 == 0:
                # lossD_list.append(loss_D)
                # lossG_list.append(loss_G)
                print('{}/{}: lossD:{}, lossG:{}'.format(i, e, loss_D, loss_G))

        visul_img = torch.cat((realA, fakeB, realA), 3)
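        # note: realA appears twice in this concatenation; if the intent is to show
        # (blurred, deblurred, sharp) side by side, the last tensor should probably be realB.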
        #print(type(visul_img), visul_img.size())
        visul_img = image_recovery(visul_img)
        #print(visul_img.size)
        save_image(visul_img,
                   os.path.join(save_img_path, 'epoch' + str(e) + '.png'))

        if e > opt.niter:
            update_lr(optim_D, opt.lr, opt.niter_decay)
            lr = update_lr(optim_G, opt.lr, opt.niter_decay)
            opt.lr = lr

        if e % opt.save_epoch_freq == 0:
            save_net(netG, opt.checkpoints_dir, 'G', e)
            save_net(netD, opt.checkpoints_dir, 'D', e)
Example #10
def train(config, tols, milestone, n=2, device='cpu'):
    lr_d = config['lr_d']
    lr_g = config['lr_g']
    optim_type = config['optimizer']
    z_dim = config['z_dim']
    model_name = config['model']
    epoch_num = config['epoch_num']
    show_iter = config['show_iter']
    loss_name = config['loss_type']
    l2_penalty = config['d_penalty']
    logdir = config['logdir']
    start_n = config['startn']
    dataset = get_data(dataname=config['dataset'], path='../datas/%s' % config['datapath'])
    dataloader = DataLoader(dataset=dataset, batch_size=config['batchsize'],
                            shuffle=True, num_workers=4)
    D, G = get_model(model_name=model_name, z_dim=z_dim)
    D.apply(weights_init_d).to(device)
    G.apply(weights_init_g).to(device)
    if optim_type == 'SCGD':
        optimizer = SCGD(max_params=G.parameters(), min_params=D.parameters(),
                         lr_max=lr_g, lr_min=lr_d,
                         tol=tols['tol'], atol=tols['atol'],
                         device=device, solver='cg')
        scheduler = lr_scheduler(optimizer=optimizer, milestone=milestone)
    if config['checkpoint'] is not None:
        startPoint = config['checkpoint']
        chk = torch.load(startPoint)
        D.load_state_dict(chk['D'])
        G.load_state_dict(chk['G'])
        optimizer.load_state_dict(chk['optim'])
        print('Start from %s' % startPoint)
    gpu_num = config['gpu_num']
    if gpu_num > 1:
        D = nn.DataParallel(D, list(range(gpu_num)))
        G = nn.DataParallel(G, list(range(gpu_num)))
    timer = time.time()
    count = 0

    if model_name == 'DCGAN' or model_name == 'DCGAN-WBN':
        fixed_noise = torch.randn((64, z_dim, 1, 1), device=device)
    else:
        fixed_noise = torch.randn((64, z_dim), device=device)
    for e in range(epoch_num):
        scheduler.step(epoch=e)
        print('======Epoch: %d / %d======' % (e, epoch_num))
        for real_x in dataloader:
            optimizer.zero_grad()
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            if model_name == 'DCGAN' or model_name == 'DCGAN-WBN':
                z = torch.randn((d_real.shape[0], z_dim, 1, 1), device=device)
            else:
                z = torch.randn((d_real.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            loss = get_loss(name=loss_name, g_loss=False,
                            d_real=d_real, d_fake=d_fake,
                            l2_weight=l2_penalty, D=D)
            optimizer.step(loss)
            if (count + 1) % n == 0:
                optimizer.update(n)
            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter :%d , Loss: %.5f, time: %.3fs'
                      % (count, loss.item(), time_cost))
                timer = time.time()
                with torch.no_grad():
                    fake_img = G(fixed_noise).detach()
                    path = 'figs/%s_%s/' % (config['dataset'], logdir)
                    if not os.path.exists(path):
                        os.makedirs(path)
                    vutils.save_image(fake_img, path + 'iter_%d.png' % (count + start_n), normalize=True)
                save_checkpoint(path=logdir,
                                name='%s-%s%.3f_%d.pth' % (optim_type, model_name, lr_g, count + start_n),
                                D=D, G=G, optimizer=optimizer)
            count += 1
Example #11
    model(tf.random.uniform((1, args.height, args.width, 3), dtype=tf.float32))
    if args.load_model:
        if os.path.exists(os.path.join(args.load_model, "saved_model.pb")):
            pretrained_model = K.models.load_model(args.load_model)
            model.set_weights(pretrained_model.get_weights())
            print("Model loaded from {} successfully".format(
                os.path.basename(args.load_model)))
        else:
            print("No file found at {}".format(
                os.path.join(args.load_model, "saved_model.pb")))

total_steps = 0
step = 0
curr_step = 0

calc_loss = losses.get_loss(name=args.loss)
cross_entropy_loss = losses.get_loss(name="cross_entropy")


def train_step(mini_batch, aux=False, pick=None):
    with tf.GradientTape() as tape:
        train_logits = model(mini_batch[0], training=True)
        train_labs = tf.one_hot(mini_batch[1][..., 0], classes)
        if aux:
            losses = [
                tf.reduce_mean(
                    calc_loss(
                        train_labs,
                        tf.image.resize(train_logit,
                                        size=train_labs.shape[1:3])))
                if n == 0 else args.aux_weight * tf.reduce_mean(
Example #12
# TODO: Add a proper way of handling logs in absence of validation or test data
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=model, iterator=dataset_train)
manager = tf.train.CheckpointManager(ckpt,  logdir + "/models/", max_to_keep=10)

# writer.set_as_default()
step = 0
for epoch in range(epochs):
    for step, (mini_batch, val_mini_batch) in enumerate(zip(dataset_train, dataset_test)):
        train_probs = tf.nn.softmax(model(mini_batch['image'] / 255))
        train_labs = tf.one_hot(mini_batch['label'], n_classes)
        val_probs = tf.nn.softmax(model(val_mini_batch['image'] / 255))
        val_labs = tf.one_hot(val_mini_batch['label'], n_classes)
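        # the softmax passes above appear to be used only for the logged metrics below;
        # train_step receives the raw images and the model, so it presumably runs its own
        # forward/backward pass.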

        loss = train_step(mini_batch['image'] / 255, train_labs, model, optimizer)
        val_loss = losses.get_loss(val_probs,
                                   val_labs,
                                   name='cross_entropy',
                                   from_logits=False)
        print("Epoch {}: {}/{}, Loss: {} Val Loss: {}".format(epoch, step * batch_size, total_samples, loss.numpy(),
                                                              val_loss.numpy()), end='     \r', flush=True)
        curr_step = total_steps + step
        if curr_step % log_freq == 0:
            with train_writer.as_default():
                tf.summary.scalar("loss", loss,
                                  step=curr_step)
            with test_writer.as_default():
                tf.summary.scalar("loss", val_loss,
                                  step=curr_step)
            for t_metric, v_metric in zip(train_metrics, val_metrics):
                _, _ = t_metric.update_state(mini_batch['label'], tf.argmax(train_probs, axis=-1)), \
                       v_metric.update_state(val_mini_batch['label'], tf.argmax(val_probs, axis=-1))
                with train_writer.as_default():
Example #13
    parser.add_argument('--lr_step', type=int, default=150)
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=350)
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--print_freq', type=int, default=10)
    parser.add_argument('--sup_indices_file', type=str, default=None)
    parser.add_argument('--temp_file', type=str, default='temp/sup_indices.json')

    options = vars(parser.parse_args())
    return options

if __name__ == "__main__":
    options = argparser()

    # The following return loss classes
    criterion_cls = get_loss(loss_name='CE') # Cross-entropy loss
    criterion_clust = get_loss(loss_name='ClusterLoss') # MEL + BEL
    # Regularization with Stochastic Transformations loss
    criterion_reg = get_loss(loss_name='STLoss')

    torch.multiprocessing.set_sharing_strategy('file_system')
    cudnn.benchmark = True

    if options['mode'] == 'train':
        if options['type'] == 'cls_clust':
            # Use only classification and clustering (MEL + BEL)
            options['gamma'] = 0
            options['delta'] = 0
        elif options['type'] == 'cls':
            # Use only classification
            options['alpha'] = 0
Example #14
def main(config_path, gpu='0'):
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    config = get_config(config_path)
    model_name = config['model_name']
    val_fold = config['val_fold']
    folds_to_use = config['folds_to_use']
    alias = config['alias']
    log_path = osp.join(config['logs_path'],
                        alias + str(val_fold) + '_' + model_name)

    device = torch.device(config['device'])
    weights = config['weights']
    loss_name = config['loss']
    optimizer_name = config['optimizer']
    lr = config['lr']
    decay = config['decay']
    momentum = config['momentum']
    epochs = config['epochs']
    fp16 = config['fp16']
    n_classes = config['n_classes']
    input_channels = config['input_channels']
    main_metric = config['main_metric']

    best_models_count = config['best_models_count']
    minimize_metric = config['minimize_metric']

    folds_file = config['folds_file']
    train_augs = config['train_augs']
    preprocessing_fn = config['preprocessing_fn']
    limit_files = config['limit_files']
    batch_size = config['batch_size']
    shuffle = config['shuffle']
    num_workers = config['num_workers']
    valid_augs = config['valid_augs']
    val_batch_size = config['val_batch_size']
    multiplier = config['multiplier']

    train_dataset = SemSegDataset(mode='train',
                                  n_classes=n_classes,
                                  folds_file=folds_file,
                                  val_fold=val_fold,
                                  folds_to_use=folds_to_use,
                                  augmentation=train_augs,
                                  preprocessing=preprocessing_fn,
                                  limit_files=limit_files,
                                  multiplier=multiplier)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=shuffle,
                              num_workers=num_workers)

    valid_dataset = SemSegDataset(mode='valid',
                                  folds_file=folds_file,
                                  n_classes=n_classes,
                                  val_fold=val_fold,
                                  folds_to_use=folds_to_use,
                                  augmentation=valid_augs,
                                  preprocessing=preprocessing_fn,
                                  limit_files=limit_files)

    valid_loader = DataLoader(dataset=valid_dataset,
                              batch_size=val_batch_size,
                              shuffle=False,
                              num_workers=num_workers)

    model = make_model(model_name=model_name).to(device)

    loss = get_loss(loss_name=loss_name)
    optimizer = get_optimizer(optimizer_name=optimizer_name,
                              model=model,
                              lr=lr,
                              momentum=momentum,
                              decay=decay)

    scheduler = None
    if config['scheduler'] == 'steps':
        print('steps lr')
        steps = config['steps']
        step_gamma = config['step_gamma']
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                         milestones=steps,
                                                         gamma=step_gamma)
    callbacks = []

    dice_callback = DiceCallback()
    callbacks.append(dice_callback)
    callbacks.append(CheckpointCallback(save_n_best=best_models_count))

    runner = SupervisedRunner(device=device)
    loaders = {'train': train_loader, 'valid': valid_loader}

    runner.train(model=model,
                 criterion=loss,
                 optimizer=optimizer,
                 loaders=loaders,
                 scheduler=scheduler,
                 callbacks=callbacks,
                 logdir=log_path,
                 num_epochs=epochs,
                 verbose=True,
                 main_metric=main_metric,
                 minimize_metric=minimize_metric,
                 fp16=fp16)
Example #15
def main():
    ''' simple starter program for tensorflow models. '''
    logging_format = '%(asctime)s %(levelname)s:%(process)s:%(thread)s:%(name)s:%(message)s'
    logging_datefmt = '%Y-%m-%d %H:%M:%S'
    logging_level = logging.INFO

    parser = argparse.ArgumentParser(description='')
    parser.add_argument(
        '-c',
        '--config',
        dest='config_filename',
        help='configuration filename in json format [default: %s]' %
        DEFAULT_CONFIG,
        default=DEFAULT_CONFIG)
    parser.add_argument(
        '--interop',
        type=int,
        help=
        'set Tensorflow "inter_op_parallelism_threads" session config varaible [default: %s]'
        % DEFAULT_INTEROP,
        default=DEFAULT_INTEROP)
    parser.add_argument(
        '--intraop',
        type=int,
        help=
        'set Tensorflow "intra_op_parallelism_threads" session config varaible [default: %s]'
        % DEFAULT_INTRAOP,
        default=DEFAULT_INTRAOP)
    parser.add_argument(
        '-l',
        '--logdir',
        default=DEFAULT_LOGDIR,
        help='define location to save log information [default: %s]' %
        DEFAULT_LOGDIR)

    parser.add_argument('--horovod',
                        dest='horovod',
                        default=False,
                        action='store_true',
                        help="Use horovod")

    parser.add_argument('--debug',
                        dest='debug',
                        default=False,
                        action='store_true',
                        help="Set Logger to DEBUG")
    parser.add_argument('--error',
                        dest='error',
                        default=False,
                        action='store_true',
                        help="Set Logger to ERROR")
    parser.add_argument('--warning',
                        dest='warning',
                        default=False,
                        action='store_true',
                        help="Set Logger to ERROR")
    parser.add_argument('--logfilename',
                        dest='logfilename',
                        default=None,
                        help='if set, logging information will go to file')
    args = parser.parse_args()

    hvd = None
    if args.horovod:
        import horovod
        import horovod.tensorflow as hvd
        hvd.init()
        logging_format = '%(asctime)s %(levelname)s:%(process)s:%(thread)s:' + (
            '%05d' % hvd.rank()) + ':%(name)s:%(message)s'

        if hvd.rank() > 0:
            logging_level = logging.WARNING

    if args.debug and not args.error and not args.warning:
        logging_level = logging.DEBUG
        os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '0'
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
    elif not args.debug and args.error and not args.warning:
        logging_level = logging.ERROR
    elif not args.debug and not args.error and args.warning:
        logging_level = logging.WARNING

    logging.basicConfig(level=logging_level,
                        format=logging_format,
                        datefmt=logging_datefmt,
                        filename=args.logfilename)

    if 'CUDA_VISIBLE_DEVICES' in os.environ:
        logging.warning('CUDA_VISIBLE_DEVICES=%s %s',
                        os.environ['CUDA_VISIBLE_DEVICES'],
                        device_lib.list_local_devices())
    else:
        logging.info('CUDA_VISIBLE_DEVICES not defined in os.environ')
    logging.info('using tensorflow version:   %s', tf.__version__)
    logging.info('using tensorflow from:      %s', tf.__file__)
    if hvd:
        logging.warning(
            'rank: %5d   size: %5d  local rank: %5d  local size: %5d',
            hvd.rank(), hvd.size(), hvd.local_rank(), hvd.local_size())

        logging.info('using horovod version:      %s', horovod.__version__)
        logging.info('using horovod from:         %s', horovod.__file__)
    logging.info('logdir:                     %s', args.logdir)
    logging.info('interop:                    %s', args.interop)
    logging.info('intraop:                    %s', args.intraop)

    device_str = '/CPU:0'
    if tf.test.is_gpu_available():
        # device_str = '/device:GPU:' + str(hvd.local_rank())
        gpus = tf.config.experimental.list_logical_devices('GPU')
        logger.warning('gpus = %s', gpus)
        # assert hvd.local_rank() < len(gpus), f'localrank = {hvd.local_rank()} len(gpus) = {len(gpus)}'
        device_str = gpus[0].name
        # logger.info('device_str = %s',device_str)

    logger.warning('device:                     %s', device_str)

    with open(args.config_filename) as config_file:
        config = json.load(config_file)
    config['device'] = device_str
    config['hvd'] = hvd

    logger.info('-=-=-=-=-=-=-=-=-  CONFIG FILE -=-=-=-=-=-=-=-=-')
    logger.info('%s = \n %s', args.config_filename,
                json.dumps(config, indent=4, sort_keys=True))
    logger.info('-=-=-=-=-=-=-=-=-  CONFIG FILE -=-=-=-=-=-=-=-=-')
    config['hvd'] = hvd

    with tf.Graph().as_default():
        logger.info('getting datasets')
        trainds, validds = data_handler.get_datasets(config)

        input_shape = (config['data']['batch_size'], ) + tuple(
            config['data']['image_shape'])
        target_shape = (config['data']['batch_size'],
                        config['data']['image_shape'][0])

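        # A single reinitializable iterator is built from the dataset structure;
        # the training and validation datasets are swapped in later by running
        # their respective make_initializer() ops.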
        iterator = tf.compat.v1.data.Iterator.from_structure(
            (tf.float32, tf.int32), (input_shape, target_shape))
        input, target = iterator.get_next()
        training_init_op = iterator.make_initializer(trainds)
        valid_init_op = iterator.make_initializer(validds)

        with tf.device(device_str):

            is_training_pl = tf.compat.v1.placeholder(tf.bool, shape=())
            batch = tf.Variable(0)
            batch_size = tf.constant(config['data']['batch_size'])

            pred, endpoints = model.get_model(input, is_training_pl, config)
            logger.info('pred = %s  target = %s', pred.shape, target.shape)
            # pred = BxC, target = BxC
            loss = losses.get_loss(config)(labels=target, logits=pred)
            #tf.compat.v1.summary.scalar('loss/combined',loss)

            #learning_rate = pointnet_seg.get_learning_rate(batch,config) * hvd.size()
            learning_rate = lr_func.get_learning_rate(batch * batch_size,
                                                      config)
            tf.compat.v1.summary.scalar('learning_rate', learning_rate)
            if config['optimizer']['name'] == 'adam':
                optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)

            # adding Horovod distributed optimizer
            if hvd:
                optimizer = hvd.DistributedOptimizer(optimizer)

            # create the training operator
            train_op = optimizer.minimize(loss, global_step=batch)

            # Add ops to save and restore all the variables.
            saver = tf.compat.v1.train.Saver()

            merged = tf.compat.v1.summary.merge_all()

        logger.info('create session')

        config_proto = tf.compat.v1.ConfigProto()
        if 'gpu' in device_str:
            config_proto.gpu_options.allow_growth = True
            config_proto.gpu_options.visible_device_list = os.environ[
                'CUDA_VISIBLE_DEVICES']
        else:
            config_proto.allow_soft_placement = True
            config_proto.intra_op_parallelism_threads = args.intraop
            config_proto.inter_op_parallelism_threads = args.interop

        # create the session using the ConfigProto assembled above
        sess = tf.compat.v1.Session(config=config_proto)

        # create tensorboard writers
        if hvd and hvd.rank() == 0:
            train_writer = tf.compat.v1.summary.FileWriter(
                os.path.join(args.logdir, 'train'), sess.graph)
            valid_writer = tf.compat.v1.summary.FileWriter(
                os.path.join(args.logdir, 'valid'), sess.graph)

        # initialize global vars and horovod broadcast initial model
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init, {is_training_pl: True})
        if hvd:
            sess.run(hvd.broadcast_global_variables(0))

        logger.info('running over data')
        status_interval = config['training']['status']
        loss_sum = 0.
        for epoch in range(config['training']['epochs']):
            logger.info('epoch %s of %s', epoch + 1,
                        config['training']['epochs'])

            # initialize the data iterator for training loop
            sess.run(training_init_op)

            # training loop
            start = time.time()
            while True:
                try:
                    # set that we are training
                    feed_dict = {is_training_pl: True}
                    summary, step, _, loss_val = sess.run(
                        [merged, batch, train_op, loss], feed_dict=feed_dict)

                    # report status periodically
                    if step % status_interval == 0:
                        end = time.time()
                        duration = end - start
                        logger.info(
                            'step: %10d    imgs/sec: %10.6f', step,
                            float(status_interval) *
                            config['data']['batch_size'] / duration)
                        start = time.time()

                # exception thrown when data is done
                except tf.errors.OutOfRangeError:
                    logger.info(' end of epoch ')
                    saver.save(sess,
                               os.path.join(args.logdir, "model.ckpt"),
                               global_step=step)
                    break

            logger.info('running validation')
            # initialize the validation data iterator
            sess.run(valid_init_op)

            steps = 0.
Example #16
def train_multi_task(param_file,
                     if_debug,
                     conn_counts_file,
                     overwrite_lr=None,
                     overwrite_lambda_reg=None,
                     overwrite_weight_decay=None):
    # print("Approx. optimal weights 0.89 0.01 0.1 (S, I, D) - from https://arxiv.org/pdf/1705.07115.pdf")
    with open(param_file) as json_params:
        params = json.load(json_params)

    if params['input_size'] == 'default':
        config_path = 'configs.json'
    elif params['input_size'] == 'bigimg':
        config_path = 'configs_big_img.json'
    elif params['input_size'] == 'biggerimg':
        config_path = 'configs_bigger_img.json'
    else:
        raise ValueError('Unknown input_size: {}'.format(params['input_size']))

    with open(config_path) as config_params:
        configs = json.load(config_params)

    def get_log_dir_name(params):
        exp_identifier = []
        for (key, val) in params.items():
            if 'tasks' in key or 'scales' in key:
                continue
            exp_identifier += ['{}={}'.format(key, val)]

        exp_identifier = '|'.join(exp_identifier)
        params['exp_id'] = exp_identifier

        run_dir_name = 'runs_debug' if if_debug else 'runsB'
        time_str = datetime.datetime.now().strftime("%H_%M_on_%B_%d")
        log_dir_name = '/mnt/antares_raid/home/awesomelemon/{}/{}'.format(
            run_dir_name, time_str)

        def print_proper_log_dir_name():
            log_dir_name_full = '/mnt/antares_raid/home/awesomelemon/{}/{}_{}'.format(
                run_dir_name, params['exp_id'], time_str)
            log_dir_name_full = re.sub(r'\s+', '_', log_dir_name_full)
            log_dir_name_full = re.sub(r"'", '_', log_dir_name_full)
            log_dir_name_full = re.sub(r'"', '_', log_dir_name_full)
            log_dir_name_full = re.sub(r':', '_', log_dir_name_full)
            log_dir_name_full = re.sub(r',', '|', log_dir_name_full)
            print(log_dir_name_full)

        print_proper_log_dir_name()

        return log_dir_name, time_str

    log_dir_name, time_str = get_log_dir_name(params)
    print(f'Log dir: {log_dir_name}')

    writer = SummaryWriter(log_dir=log_dir_name)

    train_loader, val_loader, train2_loader = datasets.get_dataset(
        params, configs)

    loss_fn = losses.get_loss(params)
    metric = metrics.get_metrics(params)

    if_scale_by_blncd_acc = 'if_scale_by_blncd_acc' in params
    if if_scale_by_blncd_acc:
        blncd_accs_val = defaultdict(lambda: 1)
        for task, loss_fn_task in list(loss_fn.items()):
            # bind the loop variables through default arguments; a plain closure
            # would capture only the last task/loss_fn_task once the loop ends
            loss_fn[task] = lambda pred, gt, _fn=loss_fn_task, _t=task: _fn(
                pred, gt, 100**(1 - blncd_accs_val[_t]))

    tasks = params['tasks']
    all_tasks = configs[params['dataset']]['all_tasks']

    arc = params['architecture']
    if_train_default_resnet = 'vanilla' in arc
    model = model_selector_automl.get_model(params)
    # summary(model['rep'], input_size=[(3, 64, 64), (1,)])
    # model = model_selector_plainnet.get_model(params)

    if_use_pretrained_resnet = params['use_pretrained_resnet']
    if if_use_pretrained_resnet:
        model_rep_dict = model['rep'].state_dict()

        def rename_key(key):
            key = re.sub('downsample', 'shortcut', key)
            if False:
                key = re.sub('(layer[34].0.conv1.)', r'\g<0>ordinary_conv.',
                             key)
            else:
                if ('layer1.0.conv1' not in key) and (
                        'layer1.0.conv2' not in key
                ):  # a hack because it's 1 AM, this is the only layer which shouldn't be replaced in the 8-block system
                    key = re.sub('(layer[1234].[01].conv[12].)',
                                 r'\g<0>ordinary_conv.', key)
                    key = re.sub(
                        '(layer[234].0.shortcut.0.)', r'\g<0>ordinary_conv.',
                        key
                    )  #"shortcut" because applied after replacing "downsample"
            return key

        if 'celeba' in params['dataset']:
            pretrained_dict = torchvision.models.resnet18(
                pretrained=True).state_dict()
            for k, v in pretrained_dict.items():
                if (rename_key(k)
                        not in model_rep_dict) and (k != 'conv1.weight'):
                    print('ACHTUNG! Following pretrained weight was ignored: ',
                          rename_key(k))
            pretrained_dict = {
                rename_key(k): v
                for k, v in pretrained_dict.items()
                if rename_key(k) in model_rep_dict and k != 'conv1.weight'
            }
        elif 'cityscapes' in params['dataset']:
            pretrained_dict = torchvision.models.resnet50(
                pretrained=True).state_dict()
            for k, v in pretrained_dict.items():
                if (rename_key(k)
                        not in model_rep_dict) and (k != 'conv1.weight'):
                    print('ACHTUNG! Following pretrained weight was ignored: ',
                          rename_key(k))
            pretrained_dict = {
                rename_key(k): v
                for k, v in pretrained_dict.items()
                if rename_key(k) in model_rep_dict
            }

        # model_rep_dict.update(pretrained_dict)
        pretrained_dict['conv1.weight'] = model_rep_dict[
            'conv1.weight']  # this is the only difference between them, as of now. If there are any missing or extraneous keys, Pytorch throws an exception
        #actually, after I enabled biases, which are not in pretrained dict, I need to add them too
        if params['if_enable_bias']:
            for op_name, op_weight in model_rep_dict.items():
                if ('bias' in op_name) and ('bn' not in op_name):
                    print(op_name)
                    pretrained_dict[op_name] = op_weight
        model['rep'].load_state_dict(pretrained_dict)

    if_continue_training = False
    if if_continue_training:
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/18_10_on_December_06/optimizer=Adam|batch_size=256|lr=0.0005|lambda_reg=0.001|dataset=celeba|normalization_type=none|algorithm=no_smart_gradient_stuff|use_approximation=True|scales={_0___0.025|__1___0.025|__2___0.025|__3_4_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/16_04_on_December_10/optimizer=Adam|batch_size=256|lr=0.0005|lambda_reg=0.0001|dataset=celeba|normalization_type=none|algorithm=no_smart_gradient_stuff|use_approximation=True|scales={_0___0.025|__1___0.025|__2___0.025|___5_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/11_10_on_December_11/optimizer=Adam|batch_size=256|lr=0.0005|lambda_reg=0.001|dataset=celeba|normalization_type=none|algorithm=no_smart_gradient_stuff|use_approximation=True|scales={_0___0.025|__1___0.025|__2___0.025|__3_10_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/06_58_on_February_26/optimizer=Adam|batch_size=256|lr=0.0005|lambda_reg=0.0001|chunks=[1|_1|_16]|architecture=resnet18|width_mul=1|dataset=celeba|normalization_type=none|algorithm=no_smart_gradient_stuff|use_approximatio_4_model.pkl'
        # state = torch.load(save_model_path)
        # model['rep'].load_state_dict(state['model_rep'])
        #
        # for t in tasks:
        #     key_name = 'model_{}'.format(t)
        #     model[t].load_state_dict(state[key_name])
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/12_42_on_April_17/optimizer=SGD|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[8|_8|_8]|architecture=binmatr2_resnet18|width_mul=1|weight_decay=0.0|connectivities_l1=0.0|if_fully_connected=True|use_pretrained_17_model.pkl'
        # param_file = 'params/binmatr2_8_8_8_sgd001_pretrain_fc_consontop.json'
        save_model_path = r'/mnt/raid/data/chebykin/saved_models/16_51_on_May_21/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_16_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/20_46_on_June_08/optimizer=SGD_Adam|batch_size=96|lr=0.004|connectivities_lr=0.0|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_decay_22_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/12_55_on_June_13/optimizer=SGD_Adam|batch_size=96|lr=0.004|connectivities_lr=0.0|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_decay_28_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/23_03_on_June_11/optimizer=SGD_Adam|batch_size=96|lr=0.004|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_60_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/22_07_on_June_22/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_90_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/00_50_on_June_24/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_120_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/22_43_on_June_24/optimizer=SGD_Adam|batch_size=96|lr=0.004|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_90_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/21_22_on_June_26/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_180_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/19_15_on_June_28/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_180_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/00_50_on_June_24/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_120_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/12_18_on_June_24/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_46_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/13_05_on_August_13/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_120_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/00_18_on_August_14/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_120_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/13_58_on_August_14/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_91_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/18_58_on_August_22/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0003|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_120_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/12_26_on_August_29/optimizer=Adam|batch_size=256|lr=0.0005|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_deca_120_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/14_50_on_August_31/optimizer=Adam|batch_size=256|lr=0.0005|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_deca_37_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/02_15_on_September_01/optimizer=Adam|batch_size=256|lr=0.0005|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_deca_100_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/12_07_on_September_01/optimizer=Adam|batch_size=256|lr=0.0005|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_deca_82_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/12_07_on_September_01/optimizer=Adam|batch_size=256|lr=0.0005|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_deca_55_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/11_28_on_September_02/optimizer=Adam|batch_size=256|lr=0.0005|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_deca_120_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/12_46_on_September_06/optimizer=SGD_Adam|batch_size=256|lr=0.005|connectivities_lr=0.001|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_120_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/20_28_on_September_21/optimizer=SGD_Adam|batch_size=128|lr=0.1|connectivities_lr=0.001|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_deca_145_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/11_49_on_November_18/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_46_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/14_57_on_November_18/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_19_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/14_53_on_November_19/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_115_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/21_08_on_November_19/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_120_model.pkl'
        # save_model_path = r'/mnt/raid/data/chebykin/saved_models/11_18_on_November_20/optimizer=SGD_Adam|batch_size=256|lr=0.01|connectivities_lr=0.0005|chunks=[64|_64|_64|_128|_128|_128|_128|_256|_256|_256|_256|_512|_512|_512|_512]|architecture=binmatr2_resnet18|width_mul=1|weight_de_120_model.pkl'
        print('Continuing training from the following path:')
        print(save_model_path)
        model = load_trained_model(param_file, save_model_path)
        state = torch.load(save_model_path)
        # print('Disabling learning of connectivities!')
        # for conn in model['rep'].connectivities:
        #     conn.requires_grad = False

    if 'prune' in params:
        prune_ratio = params['prune']

        def prune(model, prune_ratio):
            # prune all convs except the first three (conv1, layer0.conv1, layer0.conv2)
            mr = model['rep']
            # convs = list([layer for layer in model['rep'].modules() if isinstance(layer, torch.nn.Conv2d)])[3:]
            convs = [  #mr.conv1, mr.layer1[0].conv1, mr.layer1[0].conv2,
                mr.layer1[1].conv1.ordinary_conv,
                mr.layer1[1].conv2.ordinary_conv,
                mr.layer2[0].conv1.ordinary_conv,
                mr.layer2[0].conv2.ordinary_conv,
                mr.layer2[0].shortcut[0].ordinary_conv,
                mr.layer2[1].conv1.ordinary_conv,
                mr.layer2[1].conv2.ordinary_conv,
                mr.layer3[0].conv1.ordinary_conv,
                mr.layer3[0].conv2.ordinary_conv,
                mr.layer3[0].shortcut[0].ordinary_conv,
                mr.layer3[1].conv1.ordinary_conv,
                mr.layer3[1].conv2.ordinary_conv,
                mr.layer4[0].conv1.ordinary_conv,
                mr.layer4[0].conv2.ordinary_conv,
                mr.layer4[0].shortcut[0].ordinary_conv,
                mr.layer4[1].conv1.ordinary_conv,
                mr.layer4[1].conv2.ordinary_conv,
            ]
            # also prune task heads
            heads = []
            for m in model:
                if m == 'rep':
                    continue
                heads.append(model[m].linear)
            to_prune = convs + heads
            to_prune = list(zip(to_prune, ['weight'] * len(to_prune)))
            import torch.nn.utils.prune as prune
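            # global_unstructured ranks the weights of all listed (module, 'weight')
            # pairs together and masks out the globally smallest `amount` fraction
            # by L1 magnitude, rather than pruning each layer independently.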
            prune.global_unstructured(
                to_prune,
                pruning_method=prune.L1Unstructured,
                amount=prune_ratio,
            )

        prune(model, prune_ratio[0])
    else:
        print(
            'Remember that due to pruning MaskedConv2d could be manually set to normal convolution mode'
        )

    model_params = []
    if_freeze_normal_params_only = params['freeze_all_but_conns']

    # print('setting BatchNorm to eval')
    def set_bn_to_eval(m):
        classname = m.__class__.__name__
        if classname.find('BatchNorm') != -1:
            m.eval()

    # for m in list(model.keys()):
    #     if m != '12' and m != 'rep':
    #         print('This is temporary to figure out whether features are lost or just deactivated when a task is disowned')
    #         del model[m]
    #         continue
    for m in model:
        if if_freeze_normal_params_only:
            model[m].apply(set_bn_to_eval)

        cur_params = list(model[m].parameters())
        if if_freeze_normal_params_only:
            # if m != '12':
            #     print('This is temporary to figure out whether features are lost or just deactivated when a task is disowned')
            for param in cur_params:
                param.requires_grad = False
        model_params += cur_params
        model[m].to(device)

    if if_freeze_normal_params_only:
        with torch.no_grad():
            for conn in model['rep'].connectivities:
                # conn *= 0.75
                conn.requires_grad = True

        if True:
            for name, param in model['rep'].named_parameters():
                if 'bias' in name:
                    param.requires_grad = True

    #todo: remove freezing
    # for name, param in model['rep'].named_parameters():
    #     if 'chunk_strength' in name or 'bn_bias' in name:
    #         param.requires_grad = False

    print(f'Starting training with parameters \n \t{str(params)} \n')

    lr = params['lr']
    if overwrite_lr is not None:
        lr = overwrite_lr
    weight_decay = 0.0 if 'weight_decay' not in params else params[
        'weight_decay']
    if overwrite_weight_decay is not None:
        weight_decay = overwrite_weight_decay

    if_learn_task_specific_connections = True and 'fullconv' not in params[
        'architecture']

    lambda_reg = params['connectivities_l1']
    if_apply_l1_to_all_conn = params['connectivities_l1_all']

    if 'SGD_Adam' in params['optimizer']:
        sgd_optimizer = torch.optim.SGD([{
            'params': model_params
        }],
                                        lr=lr,
                                        momentum=0.9)

        connectivities_lr = params['connectivities_lr']
        adam_optimizer = torch.optim.AdamW([{
            'params': model['rep'].connectivities,
            'name': 'connectivities'
        }],
                                           lr=connectivities_lr,
                                           weight_decay=weight_decay)

        optimizer = SGD_Adam(sgd_optimizer, adam_optimizer)
    elif 'Adam' in params['optimizer']:
        connectivities_lr = params['connectivities_lr']
        optimizer = torch.optim.AdamW([{
            'params': model_params,
            'name': 'normal_params'
        }, {
            'params': model['rep'].connectivities,
            'lr': connectivities_lr,
            'name': 'connectivities'
        }],
                                      lr=lr,
                                      weight_decay=weight_decay)

        # TODO: only for computational graph visualization!
        # optimizer = torch.optim.AdamW([
        #     {'params': model_params}],
        #     lr=lr, weight_decay=weight_decay)

    elif 'SGD' in params['optimizer']:
        # optimizer = torch.optim.SGD([{'params' : model_params}, {'params':model['rep'].connectivities, 'lr' : 0.2}], lr=lr, momentum=0.9)
        optimizer = torch.optim.SGD([{
            'params': model_params,
            'name': 'normal_params'
        }, {
            'params': model['rep'].connectivities,
            'name': 'connectivities'
        }],
                                    lr=lr,
                                    momentum=0.9)
    if if_continue_training:
        if 'SGD_Adam' != params['optimizer']:
            if 'SGD' != params['optimizer']:
                optimizer.load_state_dict(state['optimizer_state'])
            else:
                print('Ignoring SGD optimizer state')

    print(model['rep'])

    error_sum_min = 1.0  # highest possible error on the scale from 0 to 1 is 1

    # train2_loader_iter = iter(train2_loader)
    NUM_EPOCHS = 120
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, NUM_EPOCHS)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1)
    scheduler = None

    print(f'NUM_EPOCHS={NUM_EPOCHS}')
    n_iter = 0

    scale = {}
    for t in tasks:
        scale[t] = float(params['scales'][t])

    def write_connectivities(n_iter):
        n_conns = len(model['rep'].connectivities)
        totals = [0] * n_conns
        actives = [0] * n_conns
        for i, conn in enumerate(model['rep'].connectivities):
            totals[i] = conn.shape[0] * conn.shape[1]
            idx = conn > 0.5
            actives[i] = idx.sum().item()

        writer.add_scalar(r'active_connections', sum(actives), n_iter)
        writer.add_scalar(r'active_%%',
                          sum(actives) * 100 / float(sum(totals)), n_iter)
        for i in range(n_conns):
            writer.add_scalar(f'active_%%_{i}',
                              actives[i] * 100 / float(totals[i]), n_iter)
            writer.add_scalar(f'active_connections_{i}', actives[i], n_iter)

        if conn_counts_file == '':
            for i, cur_con in enumerate(model['rep'].connectivities):
                for j in range(cur_con.size(0)):
                    coeffs = list(cur_con[j].cpu().detach())
                    coeffs = {str(i): coeff for i, coeff in enumerate(coeffs)}
                    writer.add_scalars(f'learning_scales_{i + 1}_{j}', coeffs,
                                       n_iter)
        else:
            with open(conn_counts_file, 'a') as f:
                f.write('\n')
                f.write(f'n_iter = {n_iter}' + '\n')
                f.write('\n')
                for i, cur_con in enumerate(model['rep'].connectivities):
                    for j in range(cur_con.size(0)):
                        coeffs = cur_con[j].cpu().detach().numpy()
                        coeffs[coeffs <= 0.5] = 0
                        coeffs[coeffs > 0.5] = 1
                        f.write(
                            f'learning_scales_{i + 1}_{j}: {int(np.sum(coeffs))}/{coeffs.shape[0]}\n'
                        )
                    f.write('\n')
                f.write('\n')
                f.write('Total connections  = ' + str(sum(totals)) + '\n')
                f.write('Active connections = ' + str(sum(actives)) + '\n')
                f.write('Active % per layer = ' + str([
                    f'{(actives[i] / float(totals[i])) * 100:.0f}'
                    for i in range(n_conns)
                ]).replace("'", '') + '\n')
                f.write(
                    f'Active % =  {(sum(actives) / float(sum(totals))) * 100:.2f}'
                    + '\n')
                f.write('Active # per layer = ' + str(actives) + '\n')

    def save_model(epoch, model, optimizer):
        state = {
            'epoch': epoch + 1,
            'model_rep': model['rep'].state_dict(),
            'connectivities': model['rep'].connectivities,
            'optimizer_state': optimizer.state_dict()
        }
        if hasattr(model['rep'], 'connectivity_comeback_multipliers'):
            state['connectivity_comeback_multipliers'] = model[
                'rep'].connectivity_comeback_multipliers
        for t in tasks:
            key_name = 'model_{}'.format(t)
            state[key_name] = model[t].state_dict()
        saved_models_prefix = '/mnt/raid/data/chebykin/saved_models/{}'.format(
            time_str)
        if not os.path.exists(saved_models_prefix):
            os.makedirs(saved_models_prefix)
        save_model_path = saved_models_prefix + "/{}_{}_model.pkl".format(
            params['exp_id'], epoch + 1)
        save_model_path = re.sub(r'\s+', '_', save_model_path)
        save_model_path = re.sub(r"'", '_', save_model_path)
        save_model_path = re.sub(r'"', '_', save_model_path)
        save_model_path = re.sub(r':', '_', save_model_path)
        save_model_path = re.sub(r',', '|', save_model_path)
        if len(save_model_path) > 255:
            save_model_path = saved_models_prefix + "/{}".format(
                params['exp_id'])[:200] + "_{}_model.pkl".format(epoch + 1)
            save_model_path = re.sub(r'\s+', '_', save_model_path)
            save_model_path = re.sub(r"'", '_', save_model_path)
            save_model_path = re.sub(r'"', '_', save_model_path)
            save_model_path = re.sub(r':', '_', save_model_path)
            save_model_path = re.sub(r',', '|', save_model_path)
        torch.save(state, save_model_path)
        if epoch == 0:
            # to properly restore model, we need source code for it
            # Note: for quite some time I've been saving ordinary binmatr instead of binmatr2. Yikes!
            copy('multi_task/models/binmatr2_multi_faces_resnet.py',
                 saved_models_prefix)
            copy('multi_task/models/pspnet.py', saved_models_prefix)
            copy('multi_task/train_multi_task_binmatr.py', saved_models_prefix)

    # torch.autograd.set_detect_anomaly(True)
    for epoch in range(NUM_EPOCHS):
        if epoch == 0:
            save_model(-1, model, optimizer)  # save initialization values
        start = timer()
        print('Epoch {} Started'.format(epoch))
        # if (epoch + 1) % 50 == 0:
        #     lr_multiplier = 0.1
        #     for param_group in optimizer.param_groups:
        #         if param_group['name'] == 'connectivities':
        #             continue #don't wanna mess with connectivities
        #         param_group['lr'] *= lr_multiplier
        #         print(f"lr of {param_group['name']} was changed")
        #     print(f'Multiply sgd-only learning rate by {lr_multiplier} at step {n_iter}')
        #     # for param_group in optimizer_val.param_groups:
        #     #     param_group['lr'] *= 0.5
        # if (epoch + 1) % 15 == 0:
        #     lambda_reg *= 1.5
        #     print(f'Increased lambda_reg to {lambda_reg}')
        # if (epoch == 90):
        #     lambda_reg_backup = lambda_reg
        #     lambda_reg = 0
        # if epoch == 95:
        #     lambda_reg = lambda_reg_backup
        # if epoch == 60:
        #     print('Dividing lambda_reg by 10!!!')
        #     lambda_reg /= 10
        if 'prune' in params:
            if epoch == 30:
                prune(model, prune_ratio[1])
            if epoch == 60:
                prune(model, prune_ratio[2])

        for m in model:
            model[m].train()
            if if_freeze_normal_params_only:
                model[m].apply(set_bn_to_eval)

        for batch_idx, batch in enumerate(train_loader):
            print(n_iter)
            n_iter += 1

            # First member is always images
            images = batch[0]
            images = images.to(device)
            labels = get_relevant_labels_from_batch(batch, all_tasks, tasks,
                                                    params, device)

            loss_data = {}

            optimizer.zero_grad()
            if False:
                loss_reg = lambda_reg * torch.norm(
                    torch.cat(
                        [con.view(-1)
                         for con in model['rep'].connectivities]), 1)
            else:
                # print('Apply l1 only to task connectivities')
                # loss_reg = lambda_reg * torch.norm(model['rep'].connectivities[-1].clone().view(-1), 1)
                if if_apply_l1_to_all_conn:
                    if True:
                        loss_reg = lambda_reg * torch.norm(
                            torch.cat([
                                con.view(-1)
                                for con in model['rep'].connectivities
                            ]), 1)
                    else:
                        loss_reg = 0
                        for con in model['rep'].connectivities:
                            loss_reg += torch.norm(con, 1)
                else:
                    loss_reg = lambda_reg * torch.norm(
                        torch.cat([model['rep'].connectivities[-1].view(-1)]),
                        1)

                # if epoch == 0:
                #     print('ACHTUNG! NOW L1 is applied immediately! Only for connections-only learning!')
                if epoch < 5:
                    loss_reg *= 0
            loss = loss_reg
            loss_reg_value = loss_reg.item()
            reps = model['rep'](images)
            # del images
            for i, t in enumerate(tasks):
                if not if_learn_task_specific_connections:
                    rep = reps
                else:
                    rep = reps[i]
                out_t, _ = model[t](rep, None)
                loss_t = loss_fn[t](out_t, labels[t])
                loss_data[t] = scale[t] * loss_t.item()
                loss = loss + scale[t] * loss_t
            loss.backward()
            # plot_grad_flow(model['rep'].named_parameters())
            optimizer.step()
            # scheduler.step()

            writer.add_scalar('training_loss', loss.item(), n_iter)
            writer.add_scalar('l1_reg_loss', loss_reg_value, n_iter)
            writer.add_scalar('training_minus_l1_reg_loss',
                              loss.item() - loss_reg_value, n_iter)
            for t in tasks:
                writer.add_scalar('training_loss_{}'.format(t), loss_data[t],
                                  n_iter)

            if n_iter == 1:
                #need to do it after the first forward pass because that normalizes them to the [0, 1] range
                write_connectivities(1)
                # for visualizing computation graph:
                # model['rep'].eval()
                # writer.add_graph(model['rep'], images[0][None, :, :, :])
                # model['rep'].train()
        if scheduler is not None:
            scheduler.step()

        for m in model:
            model[m].eval()

        tot_loss = {}
        tot_loss['l1_reg'] = lambda_reg * torch.norm(
            torch.cat([con.view(-1)
                       for con in model['rep'].connectivities]), 1)
        tot_loss['all'] = tot_loss['l1_reg']  #0.0
        for t in tasks:
            tot_loss[t] = 0.0
        num_val_batches = 0
        with torch.no_grad():
            for batch_val in val_loader:
                val_images = batch_val[0].to(device)
                labels_val = get_relevant_labels_from_batch(
                    batch_val, all_tasks, tasks, params, device)
                # labels_val = {}
                #
                # for i, t in enumerate(all_tasks):
                #     if t not in tasks:
                #         continue
                #     labels_val[t] = batch_val[i + 1]
                #     labels_val[t] = labels_val[t].to(device)

                val_reps = model['rep'](val_images)
                for i, t in enumerate(tasks):
                    if not if_learn_task_specific_connections:
                        val_rep = val_reps
                    else:
                        val_rep = val_reps[i]
                    out_t_val, _ = model[t](val_rep, None)
                    loss_t = loss_fn[t](out_t_val, labels_val[t])
                    # tot_loss['all'] += loss_t.item()
                    #todo: I think old way of calculating validation loss was wrong, because we also divided l1 loss by the number of tasks
                    tot_loss['all'] += scale[t] * loss_t.item()
                    tot_loss[t] += scale[t] * loss_t.item()
                    metric[t].update(out_t_val, labels_val[t])
                num_val_batches += 1

        error_sums = defaultdict(lambda: 0)
        for t in tasks:
            if False:
                writer.add_scalar('validation_loss_{}'.format(t),
                                  tot_loss[t] / num_val_batches, n_iter)
            metric_results = metric[t].get_result()
            for metric_key in metric_results:
                if metric_key == 'acc_blncd':
                    if if_scale_by_blncd_acc:
                        blncd_accs_val[t] = metric_results[metric_key]
                writer.add_scalar('metric_{}_{}'.format(metric_key, t),
                                  metric_results[metric_key], n_iter)
                error_sums[metric_key] += 1 - metric_results[metric_key]
            metric[t].reset()

        for metric_key in metric_results:
            error_sum = error_sums[metric_key]
            error_sum /= float(len(tasks))
            writer.add_scalar(f'average_error_{metric_key}', error_sum * 100,
                              n_iter)
            print(f'average_error_{metric_key} = {error_sum * 100}')

        # writer.add_scalar('validation_loss', tot_loss['all'] / num_val_batches / len(tasks), n_iter)
        # todo: I think old way of calculating validation loss was wrong, because we also divided l1 loss by the number of tasks
        writer.add_scalar('validation_loss', tot_loss['all'] / num_val_batches,
                          n_iter)
        writer.add_scalar('validation_loss_minus_l1_reg_loss',
                          (tot_loss['all'] - tot_loss['l1_reg']) /
                          num_val_batches, n_iter)
        # writer.add_scalar('l1_reg_loss', tot_loss['l1_reg'] / num_val_batches, n_iter)

        # write scales to log
        if not if_train_default_resnet:
            write_connectivities(n_iter)

        if epoch % 3 == 0 or (error_sum < error_sum_min
                              and epoch >= 3) or (epoch == NUM_EPOCHS - 1):
            # Save after every 3 epoch
            save_model(epoch, model, optimizer)

        error_sum_min = min(error_sum, error_sum_min)
        writer.flush()

        end = timer()
        print('Epoch ended in {}s'.format(end - start))

    writer.close()
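
Note: SGD_Adam above is a project-specific helper that is not shown on this page. A minimal sketch, assuming it simply steps an SGD optimizer (network weights) and an AdamW optimizer (connectivities) in lockstep, might be:

class SGD_Adam:
    """Hypothetical two-optimizer wrapper; the real class may do more."""

    def __init__(self, sgd_optimizer, adam_optimizer):
        self.optimizers = [sgd_optimizer, adam_optimizer]

    @property
    def param_groups(self):
        return [g for opt in self.optimizers for g in opt.param_groups]

    def zero_grad(self):
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self):
        for opt in self.optimizers:
            opt.step()

    def state_dict(self):
        return [opt.state_dict() for opt in self.optimizers]

    def load_state_dict(self, states):
        for opt, state in zip(self.optimizers, states):
            opt.load_state_dict(state)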
Example #17
def main(data_path='/data/SN6_buildings/train/AOI_11_Rotterdam/',
         config_path='/project/configs/senet154_gcc_fold1.py',
         gpu='0'):

    os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    config = get_config(config_path)
    model_name = config['model_name']
    fold_number = config['fold_number']
    alias = config['alias']
    log_path = osp.join(config['logs_path'],
                        alias + str(fold_number) + '_' + model_name)

    device = torch.device(config['device'])
    weights = config['weights']
    loss_name = config['loss']
    optimizer_name = config['optimizer']
    lr = config['lr']
    decay = config['decay']
    momentum = config['momentum']
    epochs = config['epochs']
    fp16 = config['fp16']
    n_classes = config['n_classes']
    input_channels = config['input_channels']
    main_metric = config['main_metric']

    best_models_count = config['best_models_count']
    minimize_metric = config['minimize_metric']
    min_delta = config['min_delta']

    train_images = data_path
    data_type = config['data_type']
    masks_data_path = config['masks_data_path']
    folds_file = config['folds_file']
    train_augs = config['train_augs']
    preprocessing_fn = config['preprocessing_fn']
    limit_files = config['limit_files']
    batch_size = config['batch_size']
    shuffle = config['shuffle']
    num_workers = config['num_workers']
    valid_augs = config['valid_augs']
    val_batch_size = config['val_batch_size']
    multiplier = config['multiplier']

    train_dataset = SemSegDataset(images_dir=train_images,
                                  data_type=data_type,
                                  masks_dir=masks_data_path,
                                  mode='train',
                                  n_classes=n_classes,
                                  folds_file=folds_file,
                                  fold_number=fold_number,
                                  augmentation=train_augs,
                                  preprocessing=preprocessing_fn,
                                  limit_files=limit_files,
                                  multiplier=multiplier)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=shuffle,
                              num_workers=num_workers)

    valid_dataset = SemSegDataset(images_dir=train_images,
                                  data_type=data_type,
                                  mode='valid',
                                  folds_file=folds_file,
                                  n_classes=n_classes,
                                  fold_number=fold_number,
                                  augmentation=valid_augs,
                                  preprocessing=preprocessing_fn,
                                  limit_files=limit_files)

    valid_loader = DataLoader(dataset=valid_dataset,
                              batch_size=val_batch_size,
                              shuffle=False,
                              num_workers=num_workers)

    model = make_model(model_name=model_name,
                       weights=weights,
                       n_classes=n_classes,
                       input_channels=input_channels).to(device)

    loss = get_loss(loss_name=loss_name)
    optimizer = get_optimizer(optimizer_name=optimizer_name,
                              model=model,
                              lr=lr,
                              momentum=momentum,
                              decay=decay)

    if config['scheduler'] == 'reduce_on_plateau':
        print('reduce lr')
        alpha = config['alpha']
        patience = config['patience']
        threshold = config['thershold']
        min_lr = config['min_lr']
        mode = config['scheduler_mode']
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer,
            factor=alpha,
            verbose=True,
            patience=patience,
            mode=mode,
            threshold=threshold,
            min_lr=min_lr)
    elif config['scheduler'] == 'steps':
        print('steps lr')
        steps = config['steps']
        step_gamma = config['step_gamma']
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                         milestones=steps,
                                                         gamma=step_gamma)
    else:
        scheduler = None

    callbacks = []

    dice_callback = DiceCallback()
    callbacks.append(dice_callback)
    callbacks.append(CheckpointCallback(save_n_best=best_models_count))
    callbacks.append(
        EarlyStoppingCallback(patience=config['early_stopping'],
                              metric=main_metric,
                              minimize=minimize_metric,
                              min_delta=min_delta))

    runner = SupervisedRunner(device=device)
    loaders = {'train': train_loader, 'valid': valid_loader}

    runner.train(model=model,
                 criterion=loss,
                 optimizer=optimizer,
                 loaders=loaders,
                 scheduler=scheduler,
                 callbacks=callbacks,
                 logdir=log_path,
                 num_epochs=epochs,
                 verbose=True,
                 main_metric=main_metric,
                 minimize_metric=minimize_metric,
                 fp16=fp16)
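
Note: get_optimizer here is project code that is not included in the example. Assuming it simply maps a config name to a torch.optim optimizer (the names 'sgd' and 'adam' below are guesses), a minimal version could be:

import torch


def get_optimizer(optimizer_name, model, lr, momentum, decay):
    # hypothetical factory mirroring the call site above
    name = optimizer_name.lower()
    if name == 'sgd':
        return torch.optim.SGD(model.parameters(), lr=lr,
                               momentum=momentum, weight_decay=decay)
    if name == 'adam':
        return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=decay)
    raise ValueError('Unknown optimizer: {}'.format(optimizer_name))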
Example #18
def main(args):
    try:
        world_size = int(os.environ['WORLD_SIZE'])
        rank = int(os.environ['RANK'])
        dist_url = "tcp://{}:{}".format(os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"])
    except KeyError:
        world_size = 1
        rank = 0
        dist_url = "tcp://127.0.0.1:12584"

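    # WORLD_SIZE / RANK / MASTER_ADDR / MASTER_PORT are normally injected by the
    # distributed launcher; when they are absent, the except branch above falls
    # back to a single-process group on localhost.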
    dist.init_process_group(backend='nccl', init_method=dist_url, rank=rank, world_size=world_size)
    local_rank = args.local_rank
    torch.cuda.set_device(local_rank)

    if not os.path.exists(cfg.output) and rank == 0:
        os.makedirs(cfg.output)
    else:
        time.sleep(2)

    log_root = logging.getLogger()
    init_logging(log_root, rank, cfg.output)
    train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_set, shuffle=True)
    train_loader = DataLoaderX(
        local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size,
        sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True)

    dropout = 0.4 if cfg.dataset == "webface" else 0
    backbone = get_model(args.network, dropout=dropout, fp16=cfg.fp16).to(local_rank)
    backbone_onnx = get_model(args.network, dropout=dropout, fp16=False)

    if args.resume:
        try:
            backbone_pth = os.path.join(cfg.output, "backbone.pth")
            backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank)))
            if rank == 0:
                logging.info("backbone resume successfully!")
        except (FileNotFoundError, KeyError, IndexError, RuntimeError):
            logging.info("resume fail, backbone init successfully!")

    for ps in backbone.parameters():
        dist.broadcast(ps, 0)
    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[local_rank])
    backbone.train()

    margin_softmax = losses.get_loss(args.loss)
    module_partial_fc = PartialFC(
        rank=rank, local_rank=local_rank, world_size=world_size, resume=args.resume,
        batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes,
        sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output)

    opt_backbone = torch.optim.SGD(
        params=[{'params': backbone.parameters()}],
        lr=cfg.lr / 512 * cfg.batch_size * world_size,
        momentum=0.9, weight_decay=cfg.weight_decay)
    opt_pfc = torch.optim.SGD(
        params=[{'params': module_partial_fc.parameters()}],
        lr=cfg.lr / 512 * cfg.batch_size * world_size,
        momentum=0.9, weight_decay=cfg.weight_decay)

    scheduler_backbone = torch.optim.lr_scheduler.LambdaLR(
        optimizer=opt_backbone, lr_lambda=cfg.lr_func)
    scheduler_pfc = torch.optim.lr_scheduler.LambdaLR(
        optimizer=opt_pfc, lr_lambda=cfg.lr_func)

    start_epoch = 0
    total_step = int(len(train_set) / cfg.batch_size / world_size * cfg.num_epoch)
    if rank == 0: logging.info("Total Step is: %d" % total_step)

    callback_verification = CallBackVerification(2000, rank, cfg.val_targets, cfg.rec)
    callback_logging = CallBackLogging(50, rank, total_step, cfg.batch_size, world_size, None)
    callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)

    loss = AverageMeter()
    global_step = 0
    grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
    for epoch in range(start_epoch, cfg.num_epoch):
        train_sampler.set_epoch(epoch)
        for step, (img, label) in enumerate(train_loader):
            global_step += 1
            features = F.normalize(backbone(img))
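            # forward_backward computes the (sharded) margin-softmax loss and returns
            # the gradient of that loss w.r.t. the normalized features; the backbone
            # is then updated by backpropagating x_grad through `features` below.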
            x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)
            if cfg.fp16:
                features.backward(grad_amp.scale(x_grad))
                grad_amp.unscale_(opt_backbone)
                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
                grad_amp.step(opt_backbone)
                grad_amp.update()
            else:
                features.backward(x_grad)
                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
                opt_backbone.step()

            opt_pfc.step()
            module_partial_fc.update()
            opt_backbone.zero_grad()
            opt_pfc.zero_grad()
            loss.update(loss_v, 1)
            callback_logging(global_step, loss, epoch, cfg.fp16, grad_amp)
            callback_verification(global_step, backbone)
        callback_checkpoint(global_step, backbone, module_partial_fc, backbone_onnx)
        scheduler_backbone.step()
        scheduler_pfc.step()
    dist.destroy_process_group()
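
Note: AverageMeter above is the usual running-average helper; the version below is a generic sketch, not necessarily the project's exact class:

class AverageMeter:
    """Keeps a running average of a scalar metric, e.g. the training loss."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count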
Example #19
    image, label, (args.height, args.width), cs_19)

processed_train = dataset_train.map(get_images_processed)
processed_train = processed_train.map(augmentor)
processed_val = dataset_validation.map(get_images_processed)
processed_train = processed_train.shuffle(args.shuffle_buffer).batch(
    BATCH_SIZE, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
processed_val = processed_val.shuffle(args.shuffle_buffer).batch(BATCH_SIZE, drop_remainder=True) \
    if (dataset_validation is not None) else None
processed_train = mirrored_strategy.experimental_distribute_dataset(
    processed_train)
processed_val = mirrored_strategy.experimental_distribute_dataset(
    processed_val)

if gan_mode == "hinge":
    gan_loss_obj = get_loss(name="Wasserstein")
elif gan_mode == "wgan_gp":
    gan_loss_obj = get_loss(name="Wasserstein")
elif gan_mode == "ls_gan":
    gan_loss_obj = get_loss(name="MSE")
else:
    gan_loss_obj = get_loss(name="binary_crossentropy")
kl_loss = lambda mean, logvar: 0.5 * tf.reduce_sum(
    tf.square(mean) + tf.exp(logvar) - 1 - logvar)
feature_loss = get_loss(name="FeatureLoss")

# TODO: Add Regularization loss


def discriminator_loss(real_list, generated_list):
    total_disc_loss = 0
Example #20
    def initialization(self):

        SEED = self.config.data.random_seed
        if SEED != 999:
            print("random_seed is ", SEED)
            torch.manual_seed(SEED)
            torch.cuda.manual_seed(SEED)
            np.random.seed(SEED)
        else:
            print(" no random seeds!")

        self.dataset, val_dataset, self.label = get_initial(
            self.config, train=True)  # return dataset instance
        self.label_encoder = LabelEncoder()
        self.label_encoder.fit(self.label)
        torch.cuda.empty_cache()

        self.model = get_model(self.config)
        self.optimizer = get_optimizer(self.config, self.model.parameters())
        checkpoint = get_initial_checkpoint(self.config)

        if torch.cuda.device_count() > 1:
            self.model = torch.nn.DataParallel(self.model)
        self.model = self.model.cuda()

        if self.config.loss.name == "softmax_center":
            print("Add center loss !!")

            self.center_model = get_center_loss(
                class_num=self.config.data.pid_num,
                feature_num=512,
                use_gpu=True)
            self.optimizer_center = get_center_optimizer(
                self.center_model.parameters(),
                self.config.optimizer.params.lr)

        if checkpoint is not None:
            self.last_epoch, self.step = load_checkpoint(
                self.model, self.optimizer, self.center_model,
                self.optimizer_center, checkpoint)
            print("from checkpoint {} last epoch: {}".format(
                checkpoint, self.last_epoch))

        self.sampler = get_sampler(self.dataset, self.config)
        self.loss_function = get_loss(self.config)

        self.collate_fn = get_collate_fn(self.config,
                                         self.config.data.frame_num,
                                         self.sample_type)

        if self.sampler is not None:
            self.data_loader = DataLoader(
                dataset=self.dataset,
                batch_sampler=self.sampler,
                collate_fn=self.collate_fn,
                num_workers=self.num_workers,
            )
        else:
            self.data_loader = DataLoader(
                dataset=self.dataset,
                batch_size=self.config.train.batch_size.batch1,
                collate_fn=self.collate_fn,
                num_workers=self.num_workers,
                drop_last=self.config.data.drop_last,
                shuffle=self.config.data.pid_shuffle,
            )
def train_cgd(epoch_num=10,
              optim_type='ACGD',
              startPoint=None,
              start_n=0,
              z_dim=128,
              batchsize=64,
              tols={
                  'tol': 1e-10,
                  'atol': 1e-16
              },
              l2_penalty=0.0,
              momentum=0.0,
              loss_name='WGAN',
              model_name='dc',
              model_config=None,
              data_path='None',
              show_iter=100,
              logdir='test',
              dataname='CIFAR10',
              device='cpu',
              gpu_num=1,
              ada_train=True,
              log=False,
              collect_info=False,
              args=None):
    lr_d = args['lr_d']
    lr_g = args['lr_g']
    dataset = get_data(dataname=dataname, path=data_path)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D, G = get_model(model_name=model_name, z_dim=z_dim, configs=model_config)
    D.apply(weights_init_d).to(device)
    G.apply(weights_init_g).to(device)
    if optim_type == 'BCGD':
        optimizer = BCGD(max_params=G.parameters(),
                         min_params=D.parameters(),
                         lr_max=lr_g,
                         lr_min=lr_d,
                         momentum=momentum,
                         tol=tols['tol'],
                         atol=tols['atol'],
                         device=device)
        # scheduler = lr_scheduler(optimizer=optimizer, milestone=milestone)
    elif optim_type == 'ICR':
        optimizer = ICR(max_params=G.parameters(),
                        min_params=D.parameters(),
                        lr=lr_d,
                        alpha=1.0,
                        device=device)
        # scheduler = icrScheduler(optimizer, milestone)
    elif optim_type == 'ACGD':
        optimizer = ACGD(max_params=G.parameters(),
                         min_params=D.parameters(),
                         lr_max=lr_g,
                         lr_min=lr_d,
                         tol=tols['tol'],
                         atol=tols['atol'],
                         device=device,
                         solver='cg')
        # scheduler = lr_scheduler(optimizer=optimizer, milestone=milestone)
    if startPoint is not None:
        chk = torch.load(startPoint)
        D.load_state_dict(chk['D'])
        G.load_state_dict(chk['G'])
        # optimizer.load_state_dict(chk['optim'])
        print('Start from %s' % startPoint)
    if gpu_num > 1:
        D = nn.DataParallel(D, list(range(gpu_num)))
        G = nn.DataParallel(G, list(range(gpu_num)))
    timer = time.time()
    count = 0
    if 'DCGAN' in model_name:
        fixed_noise = torch.randn((64, z_dim, 1, 1), device=device)
    else:
        fixed_noise = torch.randn((64, z_dim), device=device)

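    # rolling buffer of the discriminator's accuracy over the last `mod` batches,
    # used below to adapt its learning rate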
    mod = 10
    accs = torch.tensor([0.8 for _ in range(mod)])

    for e in range(epoch_num):
        # scheduler.step(epoch=e)
        print('======Epoch: %d / %d======' % (e, epoch_num))
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            if 'DCGAN' in model_name:
                z = torch.randn((d_real.shape[0], z_dim, 1, 1), device=device)
            else:
                z = torch.randn((d_real.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            loss = get_loss(name=loss_name,
                            g_loss=False,
                            d_real=d_real,
                            d_fake=d_fake,
                            l2_weight=l2_penalty,
                            D=D)
            optimizer.zero_grad()
            optimizer.step(loss)

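            # treat d_real > 0 and d_fake < 0 as correct calls; average the accuracy over the
            # rolling window and shrink the discriminator lr whenever it leaves the [0.80, 0.90] band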
            num_correct = torch.sum(d_real > 0) + torch.sum(d_fake < 0)
            acc = num_correct.item() / (d_real.shape[0] + d_fake.shape[0])
            accs[count % mod] = acc
            acc_indicator = sum(accs) / mod
            if acc_indicator > 0.9:
                ada_ratio = 0.05
            elif acc_indicator < 0.80:
                ada_ratio = 0.1
            else:
                ada_ratio = 1.0
            if ada_train:
                optimizer.set_lr(lr_max=lr_g, lr_min=ada_ratio * lr_d)

            if count % show_iter == 0 and count != 0:
                time_cost = time.time() - timer
                print('Iter :%d , Loss: %.5f, time: %.3fs' %
                      (count, loss.item(), time_cost))
                timer = time.time()
                with torch.no_grad():
                    fake_img = G(fixed_noise).detach()
                    path = 'figs/%s_%s/' % (dataname, logdir)
                    if not os.path.exists(path):
                        os.makedirs(path)
                    vutils.save_image(fake_img,
                                      path + 'iter_%d.png' % (count + start_n),
                                      normalize=True)
                save_checkpoint(path=logdir,
                                name='%s-%s_%d.pth' %
                                (optim_type, model_name, count + start_n),
                                D=D,
                                G=G,
                                optimizer=optimizer)
            if wandb and log:
                wandb.log(
                    {
                        'Real score': d_real.mean().item(),
                        'Fake score': d_fake.mean().item(),
                        'Loss': loss.item(),
                        'Acc_indicator': acc_indicator,
                        'Ada ratio': ada_ratio
                    },
                    step=count,
                )

            if collect_info and wandb:
                cgd_info = optimizer.get_info()
                wandb.log(
                    {
                        'CG iter num': cgd_info['iter_num'],
                        'CG runtime': cgd_info['time'],
                        'D gradient': cgd_info['grad_y'],
                        'G gradient': cgd_info['grad_x'],
                        'D hvp': cgd_info['hvp_y'],
                        'G hvp': cgd_info['hvp_x'],
                        'D cg': cgd_info['cg_y'],
                        'G cg': cgd_info['cg_x']
                    },
                    step=count)
            count += 1
Example #22
def run() -> float:
    np.random.seed(0)
    model_dir = config.experiment_dir

    logger.info('=' * 50)
    # logger.info(f'hyperparameters: {params}')

    train_loader, val_loader, test_loader, label_encoder = load_data(args.fold)
    model = create_model()

    optimizer = get_optimizer(config, model.parameters())
    lr_scheduler = get_scheduler(config, optimizer)
    lr_scheduler2 = get_scheduler(
        config, optimizer) if config.scheduler2.name else None
    criterion = get_loss(config)

    if args.weights is None:
        last_epoch = 0
        logger.info(f'training will start from epoch {last_epoch+1}')
    else:
        last_checkpoint = torch.load(args.weights)
        assert last_checkpoint['arch'] == config.model.arch
        model.load_state_dict(last_checkpoint['state_dict'])
        optimizer.load_state_dict(last_checkpoint['optimizer'])
        logger.info(f'checkpoint {args.weights} was loaded.')

        last_epoch = last_checkpoint['epoch']
        logger.info(f'loaded the model from epoch {last_epoch}')

        if args.lr_override != 0:
            set_lr(optimizer, float(args.lr_override))
        elif 'lr' in config.scheduler.params:
            set_lr(optimizer, config.scheduler.params.lr)

    if args.gen_predict:
        print('inference mode')
        generate_submission(val_loader, test_loader, model, label_encoder,
                            last_epoch, args.weights)
        sys.exit(0)

    if args.gen_features:
        print('inference mode')
        generate_features(test_loader, model, args.weights)
        sys.exit(0)

    best_score = 0.0
    best_epoch = 0

    last_lr = get_lr(optimizer)
    best_model_path = args.weights

    for epoch in range(last_epoch + 1, config.train.num_epochs + 1):
        logger.info('-' * 50)

        # if not is_scheduler_continuous(config.scheduler.name):
        #     # if we have just reduced LR, reload the best saved model
        #     lr = get_lr(optimizer)
        #     logger.info(f'learning rate {lr}')
        #
        #     if lr < last_lr - 1e-10 and best_model_path is not None:
        #         last_checkpoint = torch.load(os.path.join(model_dir, best_model_path))
        #         assert(last_checkpoint['arch']==config.model.arch)
        #         model.load_state_dict(last_checkpoint['state_dict'])
        #         optimizer.load_state_dict(last_checkpoint['optimizer'])
        #         logger.info(f'checkpoint {best_model_path} was loaded.')
        #         set_lr(optimizer, lr)
        #         last_lr = lr
        #
        #     if lr < config.train.min_lr * 1.01:
        #         logger.info('reached minimum LR, stopping')
        #         break

        get_lr(optimizer)

        train(train_loader, model, criterion, optimizer, epoch, lr_scheduler,
              lr_scheduler2)
        score = validate(val_loader, model, epoch)

        if not is_scheduler_continuous(config.scheduler.name):
            lr_scheduler.step(score)
        if lr_scheduler2 and not is_scheduler_continuous(
                config.scheduler.name):
            lr_scheduler2.step(score)

        is_best = score > best_score
        best_score = max(score, best_score)
        if is_best:
            best_epoch = epoch

        data_to_save = {
            'epoch': epoch,
            'arch': config.model.arch,
            'state_dict': model.state_dict(),
            'best_score': best_score,
            'score': score,
            'optimizer': optimizer.state_dict(),
            'options': config
        }

        filename = config.version
        if is_best:
            best_model_path = f'{filename}_f{args.fold}_e{epoch:02d}_{score:.04f}.pth'
            save_checkpoint(data_to_save, best_model_path, model_dir)

    logger.info(f'best score: {best_score:.04f}')
    return -best_score
Example #23
augmentor = lambda batch: augment_autoencoder(
    batch, size=(IMG_HEIGHT, IMG_WIDTH), crop=(CROP_HEIGHT, CROP_WIDTH))
train_A = train_A.map(
    augmentor, num_parallel_calls=tf.data.AUTOTUNE).shuffle(BUFFER_SIZE).batch(
        BATCH_SIZE,
        drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)

train_B = train_B.map(
    augmentor, num_parallel_calls=tf.data.AUTOTUNE).shuffle(BUFFER_SIZE).batch(
        BATCH_SIZE,
        drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
train_A = mirrored_strategy.experimental_distribute_dataset(train_A)
train_B = mirrored_strategy.experimental_distribute_dataset(train_B)

if gan_mode == "wgan_gp":
    gan_loss_obj = get_loss(name="Wasserstein")
elif gan_mode == "ls_gan":
    gan_loss_obj = get_loss(name="MSE")
else:
    gan_loss_obj = get_loss(name="binary_crossentropy")
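# PatchNCELoss is a contrastive patch-wise loss (as used in CUT-style translation);
# the identity term reuses a plain MSE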
patch_nce_loss = get_loss(name="PatchNCELoss")
id_loss_obj = get_loss(name="MSE")


def discriminator_loss(real, generated):
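    # in the Wasserstein mode the +/-1 tensors act as sign targets for the critic scores;
    # the other modes use the usual real=1 / fake=0 labels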
    if gan_mode == "wgan_gp":
        real_loss = gan_loss_obj(-tf.ones_like(real), real)
        generated_loss = gan_loss_obj(tf.ones_like(generated), generated)
    else:
        real_loss = gan_loss_obj(tf.ones_like(real), real)
        generated_loss = gan_loss_obj(tf.zeros_like(generated), generated)
Example #24
def train_d(epoch_num=10,
            logdir='test',
            optim='SGD',
            loss_name='JSD',
            show_iter=500,
            model_weight=None,
            load_d=False,
            load_g=False,
            compare_path=None,
            info_time=100,
            run_select=None,
            device='cpu'):
    lr_d = 0.001
    lr_g = 0.01
    batchsize = 128
    z_dim = 96
    print('discriminator lr: %.3f' % lr_d)
    dataset = get_data(dataname='MNIST', path='../datas/mnist')
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D = dc_D().to(device)
    G = dc_G(z_dim=z_dim).to(device)
    D.apply(weights_init_d)
    G.apply(weights_init_g)
    if model_weight is not None:
        chk = torch.load(model_weight)
        if load_d:
            D.load_state_dict(chk['D'])
            print('Load D from %s' % model_weight)
        if load_g:
            G.load_state_dict(chk['G'])
            print('Load G from %s' % model_weight)
    if compare_path is not None:
        discriminator = dc_D().to(device)
        model_weight = torch.load(compare_path)
        discriminator.load_state_dict(model_weight['D'])
        model_vec = torch.cat(
            [p.contiguous().view(-1) for p in discriminator.parameters()])
        print('Load discriminator from %s' % compare_path)
    if run_select is not None:
        fixed_data = torch.load(run_select)
        real_set = fixed_data['real_set']
        fake_set = fixed_data['fake_set']
        real_d = fixed_data['real_d']
        fake_d = fixed_data['fake_d']
        fixed_vec = fixed_data['pred_vec']
        print('load fixed data set')
    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    writer = SummaryWriter(log_dir='logs/%s/%s_%.3f' %
                           (logdir, current_time, lr_d))
    if optim == 'SGD':
        d_optimizer = SGD(D.parameters(), lr=lr_d)
        print('Optimizer SGD')
    else:
        d_optimizer = BCGD2(max_params=G.parameters(),
                            min_params=D.parameters(),
                            lr_max=lr_g,
                            lr_min=lr_d,
                            update_max=False,
                            device=device,
                            collect_info=True)
        print('Optimizer BCGD2')
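        # with update_max=False, BCGD2 only updates the min player (the discriminator);
        # the generator is held fixed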
    timer = time.time()
    count = 0
    d_losses = []
    g_losses = []
    for e in range(epoch_num):
        tol_correct = 0
        tol_dloss = 0
        tol_gloss = 0
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            z = torch.randn((real_x.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            D_loss = get_loss(name=loss_name,
                              g_loss=False,
                              d_real=d_real,
                              d_fake=d_fake)
            tol_dloss += D_loss.item() * real_x.shape[0]
            G_loss = get_loss(name=loss_name,
                              g_loss=True,
                              d_real=d_real,
                              d_fake=d_fake)
            tol_gloss += G_loss.item() * fake_x.shape[0]
            if compare_path is not None and count % info_time == 0:
                diff = get_diff(net=D, model_vec=model_vec)
                writer.add_scalar('Distance from checkpoint',
                                  diff.item(),
                                  global_step=count)
                if run_select is not None:
                    with torch.no_grad():
                        d_real_set = D(real_set)
                        d_fake_set = D(fake_set)
                        diff_real = torch.norm(d_real_set - real_d, p=2)
                        diff_fake = torch.norm(d_fake_set - fake_d, p=2)
                        d_vec = torch.cat([d_real_set, d_fake_set])
                        diff = torch.norm(d_vec.sub_(fixed_vec), p=2)
                        writer.add_scalars('L2 norm of pred difference', {
                            'Total': diff.item(),
                            'real set': diff_real.item(),
                            'fake set': diff_fake.item()
                        },
                                           global_step=count)
            d_optimizer.zero_grad()
            if optim == 'SGD':
                D_loss.backward()
                d_optimizer.step()
                gd = torch.norm(torch.cat(
                    [p.grad.contiguous().view(-1) for p in D.parameters()]),
                                p=2)
                gg = torch.norm(torch.cat(
                    [p.grad.contiguous().view(-1) for p in G.parameters()]),
                                p=2)
            else:
                d_optimizer.step(D_loss)
                cgdInfo = d_optimizer.get_info()
                gd = cgdInfo['grad_y']
                gg = cgdInfo['grad_x']
                writer.add_scalars('Grad', {'update': cgdInfo['update']},
                                   global_step=count)
            tol_correct += (d_real > 0).sum().item() + (d_fake <
                                                        0).sum().item()
            writer.add_scalars('Loss', {
                'D_loss': D_loss.item(),
                'G_loss': G_loss.item()
            },
                               global_step=count)
            writer.add_scalars('Grad', {
                'D grad': gd,
                'G grad': gg
            },
                               global_step=count)
            writer.add_scalars('Discriminator output', {
                'Generated image': d_fake.mean().item(),
                'Real image': d_real.mean().item()
            },
                               global_step=count)
            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter :%d , D_loss: %.5f, G_loss: %.5f, time: %.3fs' %
                      (count, D_loss.item(), G_loss.item(), time_cost))
                timer = time.time()
                save_checkpoint(path=logdir,
                                name='FixG-%.3f_%d.pth' % (lr_d, count),
                                D=D,
                                G=G)
            count += 1
    writer.close()
Example #25
def main():
    ## dynamically adjust hyper-parameters for ResNets according to base_width
    if args.base_width != 64 and 'sat' in args.loss:
        factor = 64. / args.base_width
        args.sat_alpha = args.sat_alpha**(1. / factor)
        args.sat_es = int(args.sat_es * factor)
        print(
            "Adaptive parameters adjustment: alpha = {:.3f}, Es = {:d}".format(
                args.sat_alpha, args.sat_es))

    print(args)
    global best_prec1, best_auc

    # Check the save_dir exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
        os.makedirs(os.path.join(args.save_dir, 'train'))
        os.makedirs(os.path.join(args.save_dir, 'val'))
        os.makedirs(os.path.join(args.save_dir, 'test'))

    # prepare dataset
    train_loader, val_loaders, test_loader, num_classes, image_datasets = get_loader(
        args)

    model = get_model(args, num_classes, base_width=args.base_width)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.dataset == 'nexperia_merge':
                best_auc = checkpoint['best_auc']
            else:
                best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    criterion = get_loss(args,
                         num_classes=num_classes,
                         datasets=image_datasets)
    optimizer = get_optimizer(model, args)
    scheduler = get_scheduler(optimizer, args)

    train_timeline = Timeline()
    val_timeline = Timeline()
    test_timeline = Timeline()

    if args.evaluate:
        validate(test_loader, model, args.crop)
        return

    print("*" * 40)
    start = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        scheduler.step(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, num_classes,
              train_timeline, args.sat_es, args.dataset, args.mod, args.crop)
        print("*" * 40)

        # evaluate on validation sets
        prec1 = 0
        print('val:')
        val_auc = validate(val_loaders,
                           model,
                           epoch,
                           num_classes,
                           val_timeline,
                           args.dataset,
                           state='val',
                           criterion=criterion,
                           crop=args.crop)
        print("*" * 40)

        print('test:')
        test_auc = validate(test_loader,
                            model,
                            epoch,
                            num_classes,
                            test_timeline,
                            args.dataset,
                            state='test',
                            criterion=criterion,
                            crop=args.crop)
        print("*" * 40)

        # remember best auc and save checkpoint
        is_best = val_auc > best_auc
        best_auc = max(val_auc, best_auc)
        if args.save_freq > 0 and (epoch + 1) % args.save_freq == 0:
            filename = 'checkpoint_{}.tar'.format(epoch + 1)
        else:
            filename = None
        save_checkpoint(args.save_dir, {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_auc': best_auc,
        },
                        is_best,
                        filename=filename)

    # evaluate the latest checkpoint
    print("Test acc of latest checkpoint:", end='\t')
    validate(test_loader,
             model,
             epoch,
             num_classes,
             test_timeline,
             args.dataset,
             last=True,
             crop=args.crop)
    print("*" * 40)

    # evaluate best checkpoint
    checkpoint = torch.load(os.path.join(args.save_dir, 'checkpoint_best.tar'))
    print("Best validation auc ({}th epoch): {:.2f}%".format(
        checkpoint['epoch'], best_auc * 100.))
    model.load_state_dict(checkpoint['state_dict'])
    print("Test acc of best checkpoint:", end='\t')
    validate(test_loader,
             model,
             checkpoint['epoch'],
             num_classes,
             test_timeline,
             args.dataset,
             last=True,
             crop=args.crop)
    print("*" * 40)

    time_elapsed = time.time() - start
    print('It takes {:.0f}m {:.0f}s to train.'.format(time_elapsed // 60,
                                                      time_elapsed % 60))

    # save best result
    filename = 'train_results.tar'
    save_checkpoint(args.save_dir, {
        'num_epochs': args.epochs,
        'state_dict': model.state_dict(),
    },
                    is_best=True,
                    filename=filename)

    # save soft label
    if hasattr(criterion, 'soft_labels'):
        out_fname = os.path.join(args.save_dir, 'updated_soft_labels.npy')
        np.save(out_fname, criterion.soft_labels.cpu().numpy())
        print("Updated soft labels is saved to {}".format(out_fname))

    # save timelines
    train_acc_class = torch.cat(train_timeline.acc_class, dim=0)
    train_loss_class = torch.cat(train_timeline.loss_class, dim=0)
    train_acc_bi_class = torch.cat(train_timeline.acc_bi_class, dim=0)
    train_loss_bi_class = torch.cat(train_timeline.loss_bi_class, dim=0)
    train_me_class = torch.cat(train_timeline.me_class, dim=0)
    train_me_bi_class = torch.cat(train_timeline.me_bi_class, dim=0)

    val_acc_class = torch.cat(val_timeline.acc_class, dim=0)
    val_loss_class = torch.cat(val_timeline.loss_class, dim=0)
    val_acc_bi_class = torch.cat(val_timeline.acc_bi_class, dim=0)
    val_loss_bi_class = torch.cat(val_timeline.loss_bi_class, dim=0)
    val_me_class = torch.cat(val_timeline.me_class, dim=0)
    val_me_bi_class = torch.cat(val_timeline.me_bi_class, dim=0)

    test_acc_class = torch.cat(test_timeline.acc_class, dim=0)
    test_loss_class = torch.cat(test_timeline.loss_class, dim=0)
    test_acc_bi_class = torch.cat(test_timeline.acc_bi_class, dim=0)
    test_loss_bi_class = torch.cat(test_timeline.loss_bi_class, dim=0)
    test_me_class = torch.cat(test_timeline.me_class, dim=0)
    test_me_bi_class = torch.cat(test_timeline.me_bi_class, dim=0)

    np.save(os.path.join(args.save_dir, 'train', 'loss.npy'),
            train_timeline.loss)
    np.save(os.path.join(args.save_dir, 'train', 'acc.npy'),
            train_timeline.acc)
    np.save(os.path.join(args.save_dir, 'train', 'loss_bi.npy'),
            train_timeline.loss_bi)
    np.save(os.path.join(args.save_dir, 'train', 'acc_bi.npy'),
            train_timeline.acc_bi)
    np.save(os.path.join(args.save_dir, 'train', 'loss_class.npy'),
            train_loss_class)
    np.save(os.path.join(args.save_dir, 'train', 'acc_class.npy'),
            train_acc_class)
    np.save(os.path.join(args.save_dir, 'train', 'loss_bi_class.npy'),
            train_loss_bi_class)
    np.save(os.path.join(args.save_dir, 'train', 'acc_bi_class.npy'),
            train_acc_bi_class)
    np.save(os.path.join(args.save_dir, 'train', 'margin_error.npy'),
            train_timeline.margin_error)
    np.save(os.path.join(args.save_dir, 'train', 'margin_error_bi.npy'),
            train_timeline.margin_error_bi)
    np.save(os.path.join(args.save_dir, 'train', 'margin_error_class.npy'),
            train_me_class)
    np.save(os.path.join(args.save_dir, 'train', 'margin_error_bi_class.npy'),
            train_me_bi_class)
    np.save(os.path.join(args.save_dir, 'train', 'auc.npy'),
            train_timeline.auc)
    np.save(os.path.join(args.save_dir, 'train', 'fpr_991.npy'),
            train_timeline.fpr_991)
    np.save(os.path.join(args.save_dir, 'train', 'fpr_993.npy'),
            train_timeline.fpr_993)
    np.save(os.path.join(args.save_dir, 'train', 'fpr_995.npy'),
            train_timeline.fpr_995)
    np.save(os.path.join(args.save_dir, 'train', 'fpr_997.npy'),
            train_timeline.fpr_997)
    np.save(os.path.join(args.save_dir, 'train', 'fpr_999.npy'),
            train_timeline.fpr_999)
    np.save(os.path.join(args.save_dir, 'train', 'fpr_1.npy'),
            train_timeline.fpr_1)
    print("other training details are saved to {}".format(
        os.path.join(args.save_dir, 'train')))

    np.save(os.path.join(args.save_dir, 'val', 'loss.npy'), val_timeline.loss)
    np.save(os.path.join(args.save_dir, 'val', 'acc.npy'), val_timeline.acc)
    np.save(os.path.join(args.save_dir, 'val', 'loss_bi.npy'),
            val_timeline.loss_bi)
    np.save(os.path.join(args.save_dir, 'val', 'acc_bi.npy'),
            val_timeline.acc_bi)
    np.save(os.path.join(args.save_dir, 'val', 'loss_class.npy'),
            val_loss_class)
    np.save(os.path.join(args.save_dir, 'val', 'acc_class.npy'), val_acc_class)
    np.save(os.path.join(args.save_dir, 'val', 'loss_bi_class.npy'),
            val_loss_bi_class)
    np.save(os.path.join(args.save_dir, 'val', 'acc_bi_class.npy'),
            val_acc_bi_class)
    np.save(os.path.join(args.save_dir, 'val', 'margin_error.npy'),
            val_timeline.margin_error)
    np.save(os.path.join(args.save_dir, 'val', 'margin_error_bi.npy'),
            val_timeline.margin_error_bi)
    np.save(os.path.join(args.save_dir, 'val', 'margin_error_class.npy'),
            val_me_class)
    np.save(os.path.join(args.save_dir, 'val', 'margin_error_bi_class.npy'),
            val_me_bi_class)
    np.save(os.path.join(args.save_dir, 'val', 'auc.npy'), val_timeline.auc)
    np.save(os.path.join(args.save_dir, 'val', 'fpr_991.npy'),
            val_timeline.fpr_991)
    np.save(os.path.join(args.save_dir, 'val', 'fpr_993.npy'),
            val_timeline.fpr_993)
    np.save(os.path.join(args.save_dir, 'val', 'fpr_995.npy'),
            val_timeline.fpr_995)
    np.save(os.path.join(args.save_dir, 'val', 'fpr_997.npy'),
            val_timeline.fpr_997)
    np.save(os.path.join(args.save_dir, 'val', 'fpr_999.npy'),
            val_timeline.fpr_999)
    np.save(os.path.join(args.save_dir, 'val', 'fpr_1.npy'),
            val_timeline.fpr_1)
    print("other validating details are saved to {}".format(
        os.path.join(args.save_dir, 'val')))

    np.save(os.path.join(args.save_dir, 'test', 'loss.npy'),
            test_timeline.loss)
    np.save(os.path.join(args.save_dir, 'test', 'acc.npy'), test_timeline.acc)
    np.save(os.path.join(args.save_dir, 'test', 'loss_bi.npy'),
            test_timeline.loss_bi)
    np.save(os.path.join(args.save_dir, 'test', 'acc_bi.npy'),
            test_timeline.acc_bi)
    np.save(os.path.join(args.save_dir, 'test', 'loss_class.npy'),
            test_loss_class)
    np.save(os.path.join(args.save_dir, 'test', 'acc_class.npy'),
            test_acc_class)
    np.save(os.path.join(args.save_dir, 'test', 'loss_bi_class.npy'),
            test_loss_bi_class)
    np.save(os.path.join(args.save_dir, 'test', 'acc_bi_class.npy'),
            test_acc_bi_class)
    np.save(os.path.join(args.save_dir, 'test', 'margin_error.npy'),
            test_timeline.margin_error)
    np.save(os.path.join(args.save_dir, 'test', 'margin_error_bi.npy'),
            test_timeline.margin_error_bi)
    np.save(os.path.join(args.save_dir, 'test', 'margin_error_class.npy'),
            test_me_class)
    np.save(os.path.join(args.save_dir, 'test', 'margin_error_bi_class.npy'),
            test_me_bi_class)
    np.save(os.path.join(args.save_dir, 'test', 'auc.npy'), test_timeline.auc)
    np.save(os.path.join(args.save_dir, 'test', 'fpr_991.npy'),
            test_timeline.fpr_991)
    np.save(os.path.join(args.save_dir, 'test', 'fpr_993.npy'),
            test_timeline.fpr_993)
    np.save(os.path.join(args.save_dir, 'test', 'fpr_995.npy'),
            test_timeline.fpr_995)
    np.save(os.path.join(args.save_dir, 'test', 'fpr_997.npy'),
            test_timeline.fpr_997)
    np.save(os.path.join(args.save_dir, 'test', 'fpr_999.npy'),
            test_timeline.fpr_999)
    np.save(os.path.join(args.save_dir, 'test', 'fpr_1.npy'),
            test_timeline.fpr_1)
    print("other testing details are saved to {}".format(
        os.path.join(args.save_dir, 'test')))
Example #26
def train(epoch_num=10,
          milestone=None,
          optim_type='Adam',
          lr_d=1e-4,
          lr_g=1e-4,
          startPoint=None,
          start_n=0,
          z_dim=128,
          batchsize=64,
          loss_name='WGAN',
          model_name='dc',
          model_config=None,
          data_path='None',
          show_iter=100,
          logdir='test',
          dataname='cifar10',
          device='cpu',
          gpu_num=1,
          saturating=False):
    dataset = get_data(dataname=dataname, path=data_path)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D, G = get_model(model_name=model_name, z_dim=z_dim, configs=model_config)
    D.apply(weights_init_d).to(device)
    G.apply(weights_init_g).to(device)
    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    writer = SummaryWriter(log_dir='logs/%s/%s' % (logdir, current_time))
    d_optimizer = Adam(D.parameters(), lr=lr_d, betas=(0.5, 0.999))
    g_optimizer = Adam(G.parameters(), lr=lr_g, betas=(0.5, 0.999))
    if startPoint is not None:
        chk = torch.load(startPoint)
        D.load_state_dict(chk['D'])
        G.load_state_dict(chk['G'])
        d_optimizer.load_state_dict(chk['d_optim'])
        g_optimizer.load_state_dict(chk['g_optim'])
        print('Start from %s' % startPoint)
    if gpu_num > 1:
        D = nn.DataParallel(D, list(range(gpu_num)))
        G = nn.DataParallel(G, list(range(gpu_num)))
    timer = time.time()
    count = 0
    if 'DCGAN' in model_name:
        fixed_noise = torch.randn((64, z_dim, 1, 1), device=device)
    else:
        fixed_noise = torch.randn((64, z_dim), device=device)

    for e in range(epoch_num):
        print('======Epoch: %d / %d======' % (e, epoch_num))
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            if 'DCGAN' in model_name:
                z = torch.randn((d_real.shape[0], z_dim, 1, 1), device=device)
            else:
                z = torch.randn((d_real.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            d_loss = get_loss(name=loss_name,
                              g_loss=False,
                              d_real=d_real,
                              d_fake=d_fake)
            d_optimizer.zero_grad()
            g_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

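            # non-saturating update: draw fresh noise and compute a separate generator loss;
            # in the saturating variant the generator reuses the discriminator loss computed above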
            if not saturating:
                if 'DCGAN' in model_name:
                    z = torch.randn((d_real.shape[0], z_dim, 1, 1),
                                    device=device)
                else:
                    z = torch.randn((d_real.shape[0], z_dim), device=device)
                fake_x = G(z)
                d_fake = D(fake_x)
                g_loss = get_loss(name=loss_name, g_loss=True, d_fake=d_fake)
                g_optimizer.zero_grad()
                g_loss.backward()
            else:
                g_loss = d_loss
            g_optimizer.step()

            writer.add_scalar('Loss/D loss', d_loss.item(), count)
            writer.add_scalar('Loss/G loss', g_loss.item(), count)
            writer.add_scalars('Discriminator output', {
                'Generated image': d_fake.mean().item(),
                'Real image': d_real.mean().item()
            },
                               global_step=count)
            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter %d, D Loss: %.5f, G loss: %.5f, time: %.2f s' %
                      (count, d_loss.item(), g_loss.item(), time_cost))
                timer = time.time()
                with torch.no_grad():
                    fake_img = G(fixed_noise).detach()
                    path = 'figs/%s_%s/' % (dataname, logdir)
                    if not os.path.exists(path):
                        os.makedirs(path)
                    vutils.save_image(fake_img,
                                      path + 'iter_%d.png' % (count + start_n),
                                      normalize=True)
                save_checkpoint(path=logdir,
                                name='%s-%s_%d.pth' %
                                (optim_type, model_name, count + start_n),
                                D=D,
                                G=G,
                                optimizer=d_optimizer,
                                g_optimizer=g_optimizer)
            count += 1
    writer.close()
Example #27
def main():
    print(args)

    # Check the save_dir exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True
    global best_prec1

    # prepare dataset
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    trainset = CIFAR10(root='~/datasets/CIFAR10',
                       train=True,
                       download=True,
                       transform=transform_train)
    train_loader = DataLoader(trainset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)
    num_classes = trainset.num_classes
    targets = np.asarray(trainset.targets)
    testset = CIFAR10(root='~/datasets/CIFAR10',
                      train=False,
                      download=True,
                      transform=transform_test)
    test_loader = DataLoader(testset,
                             batch_size=args.batch_size * 4,
                             shuffle=False,
                             num_workers=args.workers,
                             pin_memory=True)

    model = get_model(args, num_classes)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.cuda()

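    # the raw training labels are passed to the loss factory; some of the supported losses
    # (e.g. self-adaptive training) presumably keep a per-sample label buffer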
    criterion = get_loss(args, labels=targets, num_classes=num_classes)
    optimizer = get_optimizer(model, args)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optim_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = get_scheduler(optimizer, args)

    if args.evaluate:
        validate(test_loader, model)
        return

    print("*" * 40)
    for epoch in range(args.start_epoch, args.epochs + 1):
        scheduler.step(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        print("*" * 40)

        # evaluate on validation sets
        print("train:", end="\t")
        prec1 = validate(train_loader, model)
        print("test:", end="\t")
        prec1 = validate(test_loader, model)
        print("*" * 40)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if (epoch < 70
                and epoch % 10 == 0) or (epoch >= 70
                                         and epoch % args.save_freq == 0):
            filename = 'checkpoint_{}.tar'.format(epoch)
        else:
            filename = None
        save_checkpoint(args.save_dir, {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
            'best_prec1': best_prec1,
        },
                        is_best,
                        filename=filename)
Example #28
optim_g = opt.Adam(filter(lambda p: p.requires_grad,
                          model.g.parameters()),
                   lr=0.0008,
                   weight_decay=0.0005)

model.cuda()
model.train()
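# k is a running balance coefficient updated from the reconstruction terms each step
# (a BEGAN-style schedule, presumably)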
k = 0
for step, e in enumerate(range(50)):
    loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)
    print("%d epoch" % (e + 1))
    for data, label in loader:
        data = data.cuda()
        r_x_s, r_x_r, f_p_s, f_p_t, f_id_s, f_id_r, c_x_r, c_x_s = model(data)
        lamda = losses.update_lamda(step)
        ld, lc, lg = losses.get_loss(r_x_s, r_x_r, f_p_s, f_p_t, f_id_s,
                                     f_id_r, c_x_r, c_x_s, label, k, lamda)
        k = losses.update_k(k, r_x_r, r_x_s)

        optim_d.zero_grad()
        ld.backward(retain_graph=True)
        optim_d.step()

        optim_c.zero_grad()
        lc.backward(retain_graph=True)
        optim_c.step()

        optim_g.zero_grad()
        lg.backward(retain_graph=True)
        optim_g.step()

        if (step + 1) % 50000 == 0:
Example #29
def train_g(epoch_num=10,
            logdir='test',
            loss_name='JSD',
            show_iter=500,
            model_weight=None,
            load_d=False,
            load_g=False,
            device='cpu'):
    lr_d = 0.01
    lr_g = 0.01
    batchsize = 128
    z_dim = 96
    print('MNIST, discriminator lr: %.3f' % lr_d)
    dataset = get_data(dataname='MNIST', path='../datas/mnist')
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D = dc_D().to(device)
    G = dc_G(z_dim=z_dim).to(device)
    D.apply(weights_init_d)
    G.apply(weights_init_g)
    if model_weight is not None:
        chk = torch.load(model_weight)
        if load_d:
            D.load_state_dict(chk['D'])
            print('Load D from %s' % model_weight)
        if load_g:
            G.load_state_dict(chk['G'])
            print('Load G from %s' % model_weight)
    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    # writer = SummaryWriter(log_dir='logs/%s/%s_%.3f' % (logdir, current_time, lr_g))
    d_optimizer = SGD(D.parameters(), lr=lr_d)
    g_optimizer = SGD(G.parameters(), lr=lr_g)
    timer = time.time()
    count = 0
    for e in range(epoch_num):
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            z = torch.randn((real_x.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            D_loss = get_loss(name=loss_name,
                              g_loss=False,
                              d_real=d_real,
                              d_fake=d_fake)
            G_loss = get_loss(name=loss_name,
                              g_loss=True,
                              d_real=d_real,
                              d_fake=d_fake)
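            # only the generator is updated here; D_loss is computed purely for monitoring
            # against the frozen discriminator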
            d_optimizer.zero_grad()
            g_optimizer.zero_grad()
            G_loss.backward()
            g_optimizer.step()
            print('D_loss: {}, G_loss: {}'.format(D_loss.item(),
                                                  G_loss.item()))
            # writer.add_scalars('Loss', {'D_loss': D_loss.item(),
            #                             'G_loss': G_loss.item()},
            #                    global_step=count)
            # writer.add_scalars('Discriminator output', {'Generated image': d_fake.mean().item(),
            #                                             'Real image': d_real.mean().item()},
            #                    global_step=count)
            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter :%d , D_loss: %.5f, G_loss: %.5f, time: %.3fs' %
                      (count, D_loss.item(), G_loss.item(), time_cost))
                timer = time.time()
                save_checkpoint(path=logdir,
                                name='FixD-%.3f_%d.pth' % (lr_d, count),
                                D=D,
                                G=G)
            count += 1
Example #30
    parser.add_argument('--runs', type=int, default=10)
    parser.add_argument('--lr_step', type=int, default=150)
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=350)
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--print_freq', type=int, default=10)

    options = vars(parser.parse_args())
    return options


if __name__ == "__main__":
    options = argparser()

    # The following return loss classes
    criterion_cls = get_loss(loss_name='CE')  # Cross-entropy loss
    criterion_clust = get_loss(loss_name='ClusterLoss')  # MEL + BEL

    torch.multiprocessing.set_sharing_strategy('file_system')
    cudnn.benchmark = True

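    # alpha/beta/gamma presumably weight the clustering (MEL/BEL) and auxiliary terms;
    # zeroing them disables the corresponding components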
    if options['mode'] == 'train':
        if options['type'] == 'cls_clust':
            # Use only classification and clustering (MEL + BEL)
            options['gamma'] = 0
        elif options['type'] == 'cls':
            # Use only classification
            options['gamma'] = 0
            options['alpha'] = 0
            options['beta'] = 0
        elif options['type'] == 'cls_MEL':