def test(dataset, data_split, label_split, model, logger, epoch):
    """Evaluate a federated model with per-user (local) and global metrics.

    Args:
        dataset: full test dataset; also evaluated whole in the global pass.
        data_split: per-user index lists; data_split[m] selects user m's shard.
        label_split: per-user label subsets, attached to each batch for the model.
        model: model to evaluate (switched to eval mode).
        logger: experiment logger accumulating 'test' metrics.
        epoch: epoch number, used only in the log header.
    """
    with torch.no_grad():
        metric = Metric()
        model.train(False)
        # Local pass: evaluate each user's shard with that user's label subset.
        for m in range(cfg['num_users']):
            data_loader = make_data_loader({'test': SplitDataset(dataset, data_split[m])})['test']
            for i, input in enumerate(data_loader):
                input = collate(input)
                input_size = input['img'].size(0)
                input['label_split'] = torch.tensor(label_split[m])
                input = to_device(input, cfg['device'])
                output = model(input)
                # Average per-device losses when running under DataParallel.
                output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
                evaluation = metric.evaluate(cfg['metric_name']['test']['Local'], input, output)
                logger.append(evaluation, 'test', input_size)
        # Global pass: evaluate on the full dataset without label splitting.
        data_loader = make_data_loader({'test': dataset})['test']
        for i, input in enumerate(data_loader):
            input = collate(input)
            input_size = input['img'].size(0)
            input = to_device(input, cfg['device'])
            output = model(input)
            output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
            evaluation = metric.evaluate(cfg['metric_name']['test']['Global'], input, output)
            logger.append(evaluation, 'test', input_size)
        info = {'info': ['Model: {}'.format(cfg['model_tag']),
                         'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)]}
        logger.append(info, 'test', mean=False)
        logger.write('test', cfg['metric_name']['test']['Local'] + cfg['metric_name']['test']['Global'])
    return
def runExperiment():
    """Score images with Inception Score and FID.

    Seeds torch from the leading integer of cfg['model_tag']. When cfg['raw']
    is set, the real training images are gathered and scored, and the results
    are printed and saved; otherwise previously generated images are loaded
    from disk and handed to test().
    """
    seed = int(cfg['model_tag'].split('_')[0])
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
    process_dataset(dataset['train'])
    if cfg['raw']:
        data_loader = make_data_loader(dataset)['train']
        metric = Metric()
        img = []
        # Gather the entire training split into one tensor for the metrics.
        for i, input in enumerate(data_loader):
            input = collate(input)
            img.append(input['img'])
        img = torch.cat(img, dim=0)
        output = {'img': img}
        evaluation = metric.evaluate(cfg['metric_name']['test'], None, output)
        is_result, fid_result = evaluation['InceptionScore'], evaluation['FID']
        print('Inception Score ({}): {}'.format(cfg['data_name'], is_result))
        print('FID ({}): {}'.format(cfg['data_name'], fid_result))
        save(is_result,
             './output/result/is_generated_{}.npy'.format(cfg['data_name']),
             mode='numpy')
        save(fid_result,
             './output/result/fid_generated_{}.npy'.format(cfg['data_name']),
             mode='numpy')
    else:
        generated = np.load('./output/npy/generated_{}.npy'.format(
            cfg['model_tag']),
                            allow_pickle=True)
        # NOTE(review): test() is called with a single argument here; confirm
        # the test() in scope accepts just the generated array.
        test(generated)
    return
def train(data_loader, model, optimizer, logger, epoch):
    """Train the model for one epoch and periodically log progress/ETA.

    Args:
        data_loader: training batches.
        model: model to optimize (switched to train mode).
        optimizer: optimizer stepping over the model's parameters.
        logger: logger accumulating 'train' metrics.
        epoch: current epoch number, used for progress and ETA reporting.
    """
    metric = Metric()
    model.train(True)
    for i, input in enumerate(data_loader):
        start_time = time.time()
        input = collate(input)
        input_size = len(input['img'])
        input = to_device(input, config.PARAM['device'])
        model.zero_grad()
        output = model(input)
        # Average per-device losses when running under DataParallel.
        output['loss'] = output['loss'].mean() if config.PARAM['world_size'] > 1 else output['loss']
        output['loss'].backward()
        optimizer.step()
        # Log roughly every log_interval fraction of the epoch (+1 avoids modulo by zero).
        if i % int((len(data_loader) * config.PARAM['log_interval']) + 1) == 0:
            batch_time = time.time() - start_time
            lr = optimizer.param_groups[0]['lr']
            # ETA extrapolates the current batch time over remaining batches/epochs.
            epoch_finished_time = datetime.timedelta(seconds=round(batch_time * (len(data_loader) - i - 1)))
            exp_finished_time = epoch_finished_time + datetime.timedelta(
                seconds=round((config.PARAM['num_epochs'] - epoch) * batch_time * len(data_loader)))
            info = {'info': ['Model: {}'.format(config.PARAM['model_tag']),
                             'Train Epoch: {}({:.0f}%)'.format(epoch, 100. * i / len(data_loader)),
                             'Learning rate: {}'.format(lr), 'Epoch Finished Time: {}'.format(epoch_finished_time),
                             'Experiment Finished Time: {}'.format(exp_finished_time)]}
            logger.append(info, 'train', mean=False)
            evaluation = metric.evaluate(config.PARAM['metric_names']['train'], input, output)
            logger.append(evaluation, 'train', n=input_size)
            logger.write('train', config.PARAM['metric_names']['train'])
    return
# Esempio n. 4 (scraped example separator; commented out to keep the file parseable)
def test(data_loader, ae, model, logger, epoch):
    """Evaluate `model` on latent codes produced by autoencoder `ae`.

    Both networks are put in eval mode; each batch's images are replaced by
    their encoded representation before being fed to the model.
    """
    with torch.no_grad():
        metric = Metric()
        ae.train(False)
        model.train(False)
        for i, input in enumerate(data_loader):
            input = collate(input)
            input_size = input['img'].size(0)
            input = to_device(input, cfg['device'])
            # Replace raw images with the autoencoder's latent code.
            _, _, input['img'] = ae.encode(input['img'])
            input['img'] = input['img'].detach()
            output = model(input)
            output['loss'] = output['loss'].mean(
            ) if cfg['world_size'] > 1 else output['loss']
            evaluation = metric.evaluate(cfg['metric_name']['test'], input,
                                         output)
            logger.append(evaluation, 'test', input_size)
        # NOTE(review): this re-appends the final batch's evaluation without a
        # weight — looks like accidental double counting; confirm intended.
        logger.append(evaluation, 'test')
        info = {
            'info': [
                'Model: {}'.format(cfg['model_tag']),
                'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)
            ]
        }
        logger.append(info, 'test', mean=False)
        logger.write('test', cfg['metric_name']['test'])
    return
# Esempio n. 5 (scraped example separator; commented out to keep the file parseable)
def runExperiment():
    """Compute the Davies-Bouldin Index on raw data, or score saved samples.

    Seeds torch from the leading integer of cfg['model_tag']. With cfg['raw']
    the training images and labels are gathered and scored; otherwise a saved
    .npy of created samples is loaded and handed to test().
    """
    seed = int(cfg['model_tag'].split('_')[0])
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
    process_dataset(dataset['train'])
    if not cfg['raw']:
        # Evaluate previously created samples loaded from disk.
        created = np.load('./output/npy/created_{}.npy'.format(
            cfg['model_tag']),
                          allow_pickle=True)
        test(created)
        return
    data_loader = make_data_loader(dataset)['train']
    metric = Metric()
    imgs, labels = [], []
    # Gather the entire training split for the metric computation.
    for batch in data_loader:
        batch = collate(batch)
        imgs.append(batch['img'])
        labels.append(batch['label'])
    output = {'img': torch.cat(imgs, dim=0), 'label': torch.cat(labels, dim=0)}
    evaluation = metric.evaluate(cfg['metric_name']['test'], None, output)
    dbi_result = evaluation['DBI']
    print('Davies-Bouldin Index ({}): {}'.format(cfg['data_name'],
                                                 dbi_result))
    save(dbi_result,
         './output/result/dbi_created_{}.npy'.format(cfg['data_name']),
         mode='numpy')
    return
def test(data_loader, model, logger, epoch):
    """Evaluate the model over data_loader and optionally save reconstructions.

    When cfg['show'] is set, the last batch is reconstructed through
    model.reverse() using its latent z and the first 100 input/output images
    are written to disk as visualization grids.
    """
    with torch.no_grad():
        metric = Metric()
        model.train(False)
        for i, input in enumerate(data_loader):
            input = collate(input)
            input_size = input['img'].size(0)
            input = to_device(input, cfg['device'])
            output = model(input)
            output['loss'] = output['loss'].mean(
            ) if cfg['world_size'] > 1 else output['loss']
            evaluation = metric.evaluate(cfg['metric_name']['test'], input,
                                         output)
            logger.append(evaluation, 'test', input_size)
        # NOTE(review): this re-appends the final batch's evaluation without a
        # weight — looks like accidental double counting; confirm intended.
        logger.append(evaluation, 'test')
        info = {
            'info': [
                'Model: {}'.format(cfg['model_tag']),
                'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)
            ]
        }
        logger.append(info, 'test', mean=False)
        logger.write('test', cfg['metric_name']['test'])
        if cfg['show']:
            # Reconstruct the last batch from its latent z and save the first
            # 100 input and output images.
            input['reconstruct'] = True
            input['z'] = output['z']
            output = model.reverse(input)
            save_img(input['img'][:100],
                     './output/vis/input_{}.png'.format(cfg['model_tag']),
                     range=(-1, 1))
            save_img(output['img'][:100],
                     './output/vis/output_{}.png'.format(cfg['model_tag']),
                     range=(-1, 1))
    return
# Esempio n. 7 (scraped example separator; commented out to keep the file parseable)
def stats(data_loader, model):
    """Forward every batch through `model` in train mode without gradients.

    Presumably used to refresh running statistics (e.g. batch norm) since the
    model is in train mode but no gradients are kept — confirm against caller.
    """
    with torch.no_grad():
        model.train(True)
        for batch in data_loader:
            batch = collate(batch)
            batch = to_device(batch, cfg['device'])
            model(batch)
    return
def stats(dataset, model):
    """Build a train loader over `dataset` and forward every batch in train mode.

    Gradients are disabled; presumably this refreshes running statistics
    (e.g. batch norm) — confirm against caller.
    """
    with torch.no_grad():
        loader = make_data_loader({'train': dataset})['train']
        model.train(True)
        for batch in loader:
            model(to_device(collate(batch), cfg['device']))
    return
# Esempio n. 9 (scraped example separator; commented out to keep the file parseable)
def test(data_loader):
    """Concatenate all images from `data_loader`, rescale [-1, 1] -> [0, 255],
    and save the result as a .npy file."""
    with torch.no_grad():
        collected = [collate(batch)['img'] for batch in data_loader]
        generated = torch.cat(collected)
        generated = (generated + 1) / 2 * 255
        save(generated.numpy(), './output/npy/generated_0_{}.npy'.format(cfg['data_name']), mode='numpy')
    return
def stats(dataset, model):
    """Run data through a full-rate, tracking-enabled copy of `model`.

    Builds a model at cfg['global_model_rate'] with track=True, copies the
    (possibly partial) weights from `model`, and forwards one pass over the
    training data in train mode — presumably to accumulate running statistics
    (e.g. batch norm); confirm against caller.

    Returns:
        The tracking-enabled test model.
    """
    with torch.no_grad():
        # eval() interpolates the configured model class; cfg values are
        # trusted configuration, not untrusted input.
        test_model = eval('models.{}(model_rate=cfg["global_model_rate"], track=True).to(cfg["device"])'
                          .format(cfg['model_name']))
        # strict=False because the source model may only cover a sub-network.
        test_model.load_state_dict(model.state_dict(), strict=False)
        data_loader = make_data_loader({'train': dataset})['train']
        test_model.train(True)
        for i, input in enumerate(data_loader):
            input = collate(input)
            input = to_device(input, cfg['device'])
            test_model(input)
    return test_model
# Esempio n. 11 (scraped example separator; commented out to keep the file parseable)
def train(data_loader, ae, model, optimizer, logger, epoch):
    """Train `model` for one epoch on latent codes produced by autoencoder `ae`.

    The autoencoder is frozen (eval mode, no_grad around encoding); only
    `model` receives gradient updates. Progress and ETA are logged roughly
    every cfg['log_interval'] fraction of the epoch.
    """
    metric = Metric()
    ae.train(False)
    model.train(True)
    start_time = time.time()
    for i, input in enumerate(data_loader):
        input = collate(input)
        input_size = input['img'].size(0)
        input = to_device(input, cfg['device'])
        with torch.no_grad():
            # Replace raw images with the autoencoder's latent code.
            _, _, input['img'] = ae.encode(input['img'])
            input['img'] = input['img'].detach()
        optimizer.zero_grad()
        output = model(input)
        # Average per-device losses when running under DataParallel.
        output['loss'] = output['loss'].mean(
        ) if cfg['world_size'] > 1 else output['loss']
        output['loss'].backward()
        # Clip gradients to unit norm for stability.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        evaluation = metric.evaluate(cfg['metric_name']['train'], input,
                                     output)
        logger.append(evaluation, 'train', n=input_size)
        # Log roughly every log_interval fraction of the epoch (+1 avoids modulo by zero).
        if i % int((len(data_loader) * cfg['log_interval']) + 1) == 0:
            # Average batch time so far; used to extrapolate ETAs.
            batch_time = (time.time() - start_time) / (i + 1)
            lr = optimizer.param_groups[0]['lr']
            epoch_finished_time = datetime.timedelta(
                seconds=round(batch_time * (len(data_loader) - i - 1)))
            exp_finished_time = epoch_finished_time + datetime.timedelta(
                seconds=round((cfg['num_epochs'] - epoch) * batch_time *
                              len(data_loader)))
            info = {
                'info': [
                    'Model: {}'.format(cfg['model_tag']),
                    'Train Epoch: {}({:.0f}%)'.format(
                        epoch, 100. * i / len(data_loader)),
                    'Learning rate: {}'.format(lr),
                    'Epoch Finished Time: {}'.format(epoch_finished_time),
                    'Experiment Finished Time: {}'.format(exp_finished_time)
                ]
            }
            logger.append(info, 'train', mean=False)
            logger.write('train', cfg['metric_name']['train'])
    return
 def train(self, local_parameters, lr, logger):
     """Train a fresh local model for cfg['num_epochs']['local'] epochs.

     Args:
         local_parameters: state dict used to initialize the local model.
         lr: learning rate for the freshly built optimizer.
         logger: logger accumulating 'train' metrics.

     Returns:
         The trained model's state dict.
     """
     metric = Metric()
     # eval() interpolates the configured model class at self.model_rate;
     # cfg values are trusted configuration, not untrusted input.
     model = eval('models.{}(model_rate=self.model_rate).to(cfg["device"])'.format(cfg['model_name']))
     model.load_state_dict(local_parameters)
     model.train(True)
     optimizer = make_optimizer(model, lr)
     for local_epoch in range(1, cfg['num_epochs']['local'] + 1):
         for i, input in enumerate(self.data_loader):
             input = collate(input)
             input_size = input['img'].size(0)
             # Attach this client's label subset to the batch.
             input['label_split'] = torch.tensor(self.label_split)
             input = to_device(input, cfg['device'])
             optimizer.zero_grad()
             output = model(input)
             output['loss'].backward()
             # Clip gradients to unit norm for stability.
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
             optimizer.step()
             evaluation = metric.evaluate(cfg['metric_name']['train']['Local'], input, output)
             logger.append(evaluation, 'train', n=input_size)
     local_parameters = model.state_dict()
     return local_parameters
# Esempio n. 13 (scraped example separator; commented out to keep the file parseable)
def FID(img):
    """Compute the Frechet Inception Distance between `img` and the real data.

    Features are extracted with a pre-trained dataset-specific classifier for
    COIL100/Omniglot, and with Inception-v3 (inputs upsampled to 299x299,
    features taken after adaptive average pooling) otherwise. The FID is the
    Frechet distance between Gaussians fitted to the real and generated
    feature sets:

        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrtm(sigma1 @ sigma2))

    Args:
        img: tensor of generated images, batched via a plain DataLoader.

    Returns:
        The FID value as a Python float.

    Raises:
        ValueError: if the covariance square root has a significant imaginary
            component.
    """
    with torch.no_grad():
        batch_size = 32
        # NOTE(review): mutates the global cfg batch size as a side effect.
        cfg['batch_size']['train'] = batch_size
        dataset = fetch_dataset(cfg['data_name'], cfg['subset'], verbose=False)
        real_data_loader = make_data_loader(dataset)['train']
        generated_data_loader = DataLoader(img, batch_size=batch_size)
        if cfg['data_name'] in ['COIL100', 'Omniglot']:
            # Use a pre-trained dataset-specific classifier as feature extractor.
            model = models.classifier().to(cfg['device'])
            model_tag = ['0', cfg['data_name'], cfg['subset'], 'classifier']
            model_tag = '_'.join(filter(None, model_tag))
            checkpoint = load(
                './metrics_tf/res/classifier/{}_best.pt'.format(model_tag))
            model.load_state_dict(checkpoint['model_dict'])
            model.train(False)
            real_feature = []
            for i, input in enumerate(real_data_loader):
                input = collate(input)
                input = to_device(input, cfg['device'])
                real_feature_i = model.feature(input)
                real_feature.append(real_feature_i.cpu().numpy())
            real_feature = np.concatenate(real_feature, axis=0)
            generated_feature = []
            for i, input in enumerate(generated_data_loader):
                # Generated batches are bare tensors; wrap with dummy labels.
                input = {
                    'img': input,
                    'label': input.new_zeros(input.size(0)).long()
                }
                input = to_device(input, cfg['device'])
                generated_feature_i = model.feature(input)
                generated_feature.append(generated_feature_i.cpu().numpy())
            generated_feature = np.concatenate(generated_feature, axis=0)
        else:
            # Standard FID: Inception-v3 pooled features on 299x299 inputs.
            model = inception_v3(pretrained=True,
                                 transform_input=False).to(cfg['device'])
            up = nn.Upsample(size=(299, 299),
                             mode='bilinear',
                             align_corners=False)
            model.feature = nn.Sequential(*[
                up, model.Conv2d_1a_3x3, model.Conv2d_2a_3x3,
                model.Conv2d_2b_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2), model.Conv2d_3b_1x1,
                model.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2), model.Mixed_5b,
                model.Mixed_5c, model.Mixed_5d, model.Mixed_6a, model.Mixed_6b,
                model.Mixed_6c, model.Mixed_6d, model.Mixed_6e, model.Mixed_7a,
                model.Mixed_7b, model.Mixed_7c,
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten()
            ])
            model.train(False)
            real_feature = []
            for i, input in enumerate(real_data_loader):
                input = collate(input)
                input = to_device(input, cfg['device'])
                real_feature_i = model.feature(input['img'])
                real_feature.append(real_feature_i.cpu().numpy())
            real_feature = np.concatenate(real_feature, axis=0)
            generated_feature = []
            for i, input in enumerate(generated_data_loader):
                input = to_device(input, cfg['device'])
                generated_feature_i = model.feature(input)
                generated_feature.append(generated_feature_i.cpu().numpy())
            generated_feature = np.concatenate(generated_feature, axis=0)
        # Fit a Gaussian (mean, covariance) to each feature set.
        mu1 = np.mean(real_feature, axis=0)
        sigma1 = np.cov(real_feature, rowvar=False)
        mu2 = np.mean(generated_feature, axis=0)
        sigma2 = np.cov(generated_feature, rowvar=False)
        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)
        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)
        assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
        assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
        diff = mu1 - mu2
        # product might be almost singular
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            # Regularize with a small diagonal offset and retry.
            offset = np.eye(sigma1.shape[0]) * 1e-6
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
        # numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError("Imaginary component {}".format(m))
            covmean = covmean.real
        tr_covmean = np.trace(covmean)
        fid = diff.dot(diff) + np.trace(sigma1) + np.trace(
            sigma2) - 2 * tr_covmean
        fid = fid.item()
    return fid
# Esempio n. 14 (scraped example separator; commented out to keep the file parseable)
def summarize(data_loader, model):
    """Collect per-module I/O sizes, parameter masks, and FLOPs via forward hooks.

    Registers a forward hook on every leaf module, runs a single batch through
    the model to trigger each hook once, then aggregates parameter counts,
    FLOPs, and storage size.

    Args:
        data_loader: source of the probe batch (MNIST/CIFAR10), or of the
            dataset wrapped into a BatchDataset (WikiText2).
        model: model to summarize; left in train mode.

    Returns:
        OrderedDict with per-module entries under 'module', class occurrence
        counts under 'count', plus 'total_num_params', 'total_num_flops', and
        'total_space' (MB at 32 bits per parameter).

    Raises:
        ValueError: if cfg['data_name'] is not a supported dataset.
    """
    def register_hook(module):

        def hook(module, input, output):
            # Number each occurrence of a module class (e.g. Conv2d_1, Conv2d_2).
            module_name = str(module.__class__.__name__)
            if module_name not in summary['count']:
                summary['count'][module_name] = 1
            else:
                summary['count'][module_name] += 1
            key = str(hash(module))
            if key not in summary['module']:
                summary['module'][key] = OrderedDict()
                summary['module'][key]['module_name'] = '{}_{}'.format(module_name, summary['count'][module_name])
                summary['module'][key]['input_size'] = []
                summary['module'][key]['output_size'] = []
                summary['module'][key]['params'] = {}
                summary['module'][key]['flops'] = make_flops(module, input, output)
            input_size, output_size = make_size(input, output)
            summary['module'][key]['input_size'].append(input_size)
            summary['module'][key]['output_size'].append(output_size)
            # Record a zeroed mask per trainable weight/bias tensor; any other
            # parameter names are ignored.
            for name, param in module.named_parameters():
                if param.requires_grad:
                    if name in ['weight', 'in_proj_weight', 'out_proj.weight']:
                        if name not in summary['module'][key]['params']:
                            summary['module'][key]['params'][name] = {}
                            summary['module'][key]['params'][name]['size'] = list(param.size())
                            summary['module'][key]['coordinates'] = []
                            summary['module'][key]['params'][name]['mask'] = torch.zeros(
                                summary['module'][key]['params'][name]['size'], dtype=torch.long,
                                device=cfg['device'])
                    elif name in ['bias', 'in_proj_bias', 'out_proj.bias']:
                        if name not in summary['module'][key]['params']:
                            summary['module'][key]['params'][name] = {}
                            summary['module'][key]['params'][name]['size'] = list(param.size())
                            summary['module'][key]['params'][name]['mask'] = torch.zeros(
                                summary['module'][key]['params'][name]['size'], dtype=torch.long,
                                device=cfg['device'])
                    else:
                        continue
            if len(summary['module'][key]['params']) == 0:
                return
            # Mark every recorded parameter element as used in this pass.
            for name in summary['module'][key]['params']:
                summary['module'][key]['params'][name]['mask'] += 1
            return

        # Hook only leaf modules; skip containers and the root model itself.
        if not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList) \
                and not isinstance(module, nn.ModuleDict) and module != model:
            hooks.append(module.register_forward_hook(hook))
        return

    run_mode = True
    summary = OrderedDict()
    summary['module'] = OrderedDict()
    summary['count'] = OrderedDict()
    hooks = []
    model.train(run_mode)
    model.apply(register_hook)
    # One batch is enough to trigger every hook once.
    if cfg['data_name'] in ['MNIST', 'CIFAR10']:
        for i, input in enumerate(data_loader):
            input = collate(input)
            input = to_device(input, cfg['device'])
            model(input)
            break
    elif cfg['data_name'] in ['WikiText2']:
        dataset = BatchDataset(data_loader.dataset, cfg['bptt'])
        for i, input in enumerate(dataset):
            input = to_device(input, cfg['device'])
            model(input)
            break
    else:
        raise ValueError('Not valid data name')
    for h in hooks:
        h.remove()
    summary['total_num_params'] = 0
    summary['total_num_flops'] = 0
    for key in summary['module']:
        num_params = 0
        for name in summary['module'][key]['params']:
            num_params += (summary['module'][key]['params'][name]['mask'] > 0).sum().item()
        # Bug fix: FLOPs were previously accumulated inside the parameter loop,
        # counting a module's FLOPs once per parameter tensor (doubled for
        # modules with both weight and bias) and skipping parameterless hooked
        # modules entirely. Count each module's FLOPs exactly once instead.
        num_flops = summary['module'][key]['flops']
        summary['total_num_params'] += num_params
        summary['total_num_flops'] += num_flops
    summary['total_space'] = summary['total_num_params'] * 32. / 8 / (1024 ** 2.)
    return summary
# Esempio n. 15 (scraped example separator; commented out to keep the file parseable)
def summarize(data_loader, model, ae=None):
    """Collect per-module I/O sizes and parameter-usage masks via forward hooks.

    Runs a single batch (optionally encoded by `ae`) through `model` with a
    forward hook on every leaf module, then totals the number of parameter
    elements touched and the implied storage size.

    Args:
        data_loader: source of the single probe batch.
        model: model to summarize; left in train mode.
        ae: optional autoencoder; when given, images are replaced by their
            encoded representation before the forward pass.

    Returns:
        OrderedDict with per-module entries under 'module', class occurrence
        counts under 'count', plus 'total_num_param' and 'total_space_param'
        (MB at 32 bits per parameter).

    Raises:
        ValueError: if a hooked module has parameters but no recorded weight,
            or a mask update sees an unexpected coordinate dimension.
    """
    def register_hook(module):

        def hook(module, input, output):
            # Number each occurrence of a module class (e.g. Conv2d_1, Conv2d_2).
            module_name = str(module.__class__.__name__)
            if module_name not in summary['count']:
                summary['count'][module_name] = 1
            else:
                summary['count'][module_name] += 1
            key = str(hash(module))
            if key not in summary['module']:
                summary['module'][key] = OrderedDict()
                summary['module'][key]['module_name'] = '{}_{}'.format(module_name, summary['count'][module_name])
                summary['module'][key]['input_size'] = []
                summary['module'][key]['output_size'] = []
                summary['module'][key]['params'] = {}
            input_size = make_size(input)
            output_size = make_size(output)
            summary['module'][key]['input_size'].append(input_size)
            summary['module'][key]['output_size'].append(output_size)
            for name, param in module.named_parameters():
                if param.requires_grad:
                    # 'weight_orig' is recorded under 'weight' — presumably a
                    # reparameterized module (e.g. pruning/spectral norm); confirm.
                    if name in ['weight', 'weight_orig']:
                        if name not in summary['module'][key]['params']:
                            summary['module'][key]['params']['weight'] = {}
                            summary['module'][key]['params']['weight']['size'] = list(param.size())
                            summary['module'][key]['coordinates'] = []
                            summary['module'][key]['params']['weight']['mask'] = torch.zeros(
                                summary['module'][key]['params']['weight']['size'], dtype=torch.long,
                                device=cfg['device'])
                    elif name == 'bias':
                        if name not in summary['module'][key]['params']:
                            summary['module'][key]['params']['bias'] = {}
                            summary['module'][key]['params']['bias']['size'] = list(param.size())
                            summary['module'][key]['params']['bias']['mask'] = torch.zeros(
                                summary['module'][key]['params']['bias']['size'], dtype=torch.long,
                                device=cfg['device'])
                    else:
                        continue
            if len(summary['module'][key]['params']) == 0:
                return
            if 'weight' in summary['module'][key]['params']:
                # Full index grids along each weight dimension for this pass.
                weight_size = summary['module'][key]['params']['weight']['size']
                summary['module'][key]['coordinates'].append(
                    [torch.arange(weight_size[i], device=cfg['device']) for i in range(len(weight_size))])
            else:
                raise ValueError('Not valid parametrized module')
            for name in summary['module'][key]['params']:
                coordinates = summary['module'][key]['coordinates'][-1]
                if name == 'weight':
                    if len(coordinates) == 1:
                        summary['module'][key]['params'][name]['mask'][coordinates[0]] += 1
                    elif len(coordinates) >= 2:
                        # Index only the first two dims (out/in); for >2-D
                        # weights this marks whole kernel slices at once.
                        summary['module'][key]['params'][name]['mask'][
                            coordinates[0].view(-1, 1), coordinates[1].view(1, -1),] += 1
                    else:
                        raise ValueError('Not valid coordinates dimension')
                elif name == 'bias':
                    if len(coordinates) == 1:
                        summary['module'][key]['params'][name]['mask'] += 1
                    elif len(coordinates) >= 2:
                        summary['module'][key]['params'][name]['mask'] += 1
                    else:
                        raise ValueError('Not valid coordinates dimension')
                else:
                    raise ValueError('Not valid parameters type')
            return

        # Hook only leaf modules; skip containers and the root model itself.
        if not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList) \
                and not isinstance(module, nn.ModuleDict) and module != model:
            hooks.append(module.register_forward_hook(hook))
        return

    run_mode = True
    summary = OrderedDict()
    summary['module'] = OrderedDict()
    summary['count'] = OrderedDict()
    hooks = []
    model.train(run_mode)
    model.apply(register_hook)
    # One batch is enough to trigger every hook once.
    for i, input in enumerate(data_loader):
        input = collate(input)
        input = to_device(input, cfg['device'])
        if ae is not None:
            with torch.no_grad():
                # Replace raw images with the autoencoder's latent code.
                _, _, input['img'] = ae.encode(input['img'])
                input['img'] = input['img'].detach()
        model(input)
        break
    for h in hooks:
        h.remove()
    summary['total_num_param'] = 0
    for key in summary['module']:
        num_params = 0
        for name in summary['module'][key]['params']:
            num_params += (summary['module'][key]['params'][name]['mask'] > 0).sum().item()
        summary['total_num_param'] += num_params
    summary['total_space_param'] = abs(summary['total_num_param'] * 32. / 8 / (1024 ** 2.))
    return summary
def runExperiment():
    """Full training loop: init, optional resume, train/test, checkpointing.

    Seeds torch from the leading integer of cfg['model_tag'], performs a
    forward pass over the first cfg['num_init_batches'] batches before
    creating the optimizer (presumably data-dependent initialization — TODO
    confirm), then trains for cfg['num_epochs'] epochs, saving a checkpoint
    each epoch and copying it to *_best.pt when cfg['pivot_metric'] improves
    (lower is better, per the comparison below).
    """
    seed = int(cfg['model_tag'].split('_')[0])
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
    process_dataset(dataset['train'])
    data_loader = make_data_loader(dataset)
    # eval() interpolates the configured model class; cfg is trusted config.
    model = eval('models.{}().to(cfg["device"])'.format(cfg['model_name']))
    init_batches = {'img': [], 'label': []}
    with torch.no_grad():
        # Forward the first num_init_batches batches once before optimization.
        for input in islice(data_loader['train'], None,
                            cfg['num_init_batches']):
            for k in init_batches:
                init_batches[k].extend(input[k])
        init_batches = collate(init_batches)
        init_batches = to_device(init_batches, cfg['device'])
        model(init_batches)
    optimizer = make_optimizer(model)
    scheduler = make_scheduler(optimizer)
    # resume_mode 1: full resume; 2: weights only with a fresh logger;
    # otherwise start from scratch.
    if cfg['resume_mode'] == 1:
        last_epoch, model, optimizer, scheduler, logger = resume(
            model, cfg['model_tag'], optimizer, scheduler)
    elif cfg['resume_mode'] == 2:
        last_epoch = 1
        _, model, _, _, _ = resume(model, cfg['model_tag'])
        logger_path = 'output/runs/{}_{}'.format(
            cfg['model_tag'],
            datetime.datetime.now().strftime('%b%d_%H-%M-%S'))
        logger = Logger(logger_path)
    else:
        last_epoch = 1
        logger_path = 'output/runs/train_{}_{}'.format(
            cfg['model_tag'],
            datetime.datetime.now().strftime('%b%d_%H-%M-%S'))
        logger = Logger(logger_path)
    if cfg['world_size'] > 1:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(
                                          cfg['world_size'])))
    for epoch in range(last_epoch, cfg['num_epochs'] + 1):
        logger.safe(True)
        train(data_loader['train'], model, optimizer, logger, epoch)
        # NOTE(review): evaluation also runs on the train loader — confirm
        # there is intentionally no separate test split here.
        test(data_loader['train'], model, logger, epoch)
        if cfg['scheduler_name'] == 'ReduceLROnPlateau':
            scheduler.step(
                metrics=logger.mean['test/{}'.format(cfg['pivot_metric'])])
        else:
            scheduler.step()
        logger.safe(False)
        # Unwrap DataParallel before saving so the state dict keys are clean.
        model_state_dict = model.module.state_dict(
        ) if cfg['world_size'] > 1 else model.state_dict()
        save_result = {
            'cfg': cfg,
            'epoch': epoch + 1,
            'model_dict': model_state_dict,
            'optimizer_dict': optimizer.state_dict(),
            'scheduler_dict': scheduler.state_dict(),
            'logger': logger
        }
        save(save_result,
             './output/model/{}_checkpoint.pt'.format(cfg['model_tag']))
        # Track the best model: lower pivot metric is better.
        if cfg['pivot'] > logger.mean['test/{}'.format(cfg['pivot_metric'])]:
            cfg['pivot'] = logger.mean['test/{}'.format(cfg['pivot_metric'])]
            shutil.copy(
                './output/model/{}_checkpoint.pt'.format(cfg['model_tag']),
                './output/model/{}_best.pt'.format(cfg['model_tag']))
        logger.reset()
    logger.safe(False)
    return
def train(data_loader, model, optimizer, logger, epoch):
    """Train a conditional GAN for one epoch with alternating D and G updates.

    Per batch, the discriminator is updated cfg['iter']['discriminator'] times
    and the generator cfg['iter']['generator'] times, with either BCE or hinge
    losses selected by cfg['loss_type'].

    Args:
        data_loader: training batches.
        model: GAN wrapper exposing generate() and discriminate().
        optimizer: dict with 'generator' and 'discriminator' optimizers.
        logger: logger accumulating 'train' metrics.
        epoch: current epoch number, used for progress/ETA reporting.
    """
    metric = Metric()
    model.train(True)
    start_time = time.time()
    for i, input in enumerate(data_loader):
        input = collate(input)
        input_size = input['img'].size(0)
        input = to_device(input, cfg['device'])
        ############################
        # (1) Update D network
        ###########################
        for _ in range(cfg['iter']['discriminator']):
            # train with real
            optimizer['discriminator'].zero_grad()
            optimizer['generator'].zero_grad()
            D_x = model.discriminate(input['img'], input[cfg['subset']])
            # train with fake
            z1 = torch.randn(input['img'].size(0), cfg['gan']['latent_size'], device=cfg['device'])
            generated = model.generate(input[cfg['subset']], z1)
            # detach() so the generator gets no gradient from the D update.
            D_G_z1 = model.discriminate(generated.detach(), input[cfg['subset']])
            if cfg['loss_type'] == 'BCE':
                D_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    D_x, torch.ones((input['img'].size(0), 1), device=cfg['device'])) + \
                         torch.nn.functional.binary_cross_entropy_with_logits(
                             D_G_z1, torch.zeros((input['img'].size(0), 1), device=cfg['device']))
            elif cfg['loss_type'] == 'Hinge':
                D_loss = torch.nn.functional.relu(1.0 - D_x).mean() + torch.nn.functional.relu(1.0 + D_G_z1).mean()
            else:
                raise ValueError('Not valid loss type')
            D_loss.backward()
            optimizer['discriminator'].step()
        ############################
        # (2) Update G network
        ###########################
        for _ in range(cfg['iter']['generator']):
            optimizer['discriminator'].zero_grad()
            optimizer['generator'].zero_grad()
            z2 = torch.randn(input['img'].size(0), cfg['gan']['latent_size'], device=cfg['device'])
            generated = model.generate(input[cfg['subset']], z2)
            D_G_z2 = model.discriminate(generated, input[cfg['subset']])
            if cfg['loss_type'] == 'BCE':
                G_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    D_G_z2, torch.ones((input['img'].size(0), 1), device=cfg['device']))
            elif cfg['loss_type'] == 'Hinge':
                G_loss = -D_G_z2.mean()
            else:
                raise ValueError('Not valid loss type')
            G_loss.backward()
            optimizer['generator'].step()
        # Report |D_loss - G_loss| as the scalar 'loss' summary for logging.
        output = {'loss': abs(D_loss - G_loss), 'loss_D': D_loss, 'loss_G': G_loss}
        evaluation = metric.evaluate(cfg['metric_name']['train'], input, output)
        logger.append(evaluation, 'train', n=input_size)
        # Log roughly every log_interval fraction of the epoch (+1 avoids modulo by zero).
        if i % int((len(data_loader) * cfg['log_interval']) + 1) == 0:
            batch_time = (time.time() - start_time) / (i + 1)
            generator_lr, discriminator_lr = optimizer['generator'].param_groups[0]['lr'], \
                                             optimizer['discriminator'].param_groups[0]['lr']
            epoch_finished_time = datetime.timedelta(seconds=round(batch_time * (len(data_loader) - i - 1)))
            exp_finished_time = epoch_finished_time + datetime.timedelta(
                seconds=round((cfg['num_epochs'] - epoch) * batch_time * len(data_loader)))
            info = {'info': ['Model: {}'.format(cfg['model_tag']),
                             'Train Epoch: {}({:.0f}%)'.format(epoch, 100. * i / len(data_loader)),
                             'Learning rate : (G: {}, D: {})'.format(generator_lr, discriminator_lr),
                             'Epoch Finished Time: {}'.format(epoch_finished_time),
                             'Experiment Finished Time: {}'.format(exp_finished_time)]}
            logger.append(info, 'train', mean=False)
            logger.write('train', cfg['metric_name']['train'])
    return