コード例 #1
0
    # Log demo tensors to TensorBoard every 10 iterations (n_iter, writer,
    # sample_rate, freqs and resnet18 come from outside this excerpt).
    dummy_img = torch.rand(32, 3, 64, 64)  # output from network
    if n_iter % 10 == 0:
        # Tile the batch into one image grid, rescaling each image to [0, 1].
        x = vutils.make_grid(dummy_img, normalize=True, scale_each=True)
        writer.add_image('Image', x, n_iter)

        # Buffer for 2 seconds of audio at the given sample rate.
        dummy_audio = torch.zeros(sample_rate * 2)
        # NOTE(review): x.size(0) is the grid's channel count (3 after
        # make_grid), so only the first few samples are filled — likely
        # meant dummy_audio.size(0); confirm.
        for i in range(x.size(0)):
            # amplitude of sound should be in [-1, 1]
            dummy_audio[i] = np.cos(freqs[n_iter // 10] * np.pi * float(i) /
                                    float(sample_rate))
        writer.add_audio('myAudio',
                         dummy_audio,
                         n_iter,
                         sample_rate=sample_rate)

        writer.add_text('Text', 'text logged at step:' + str(n_iter), n_iter)

        # Histogram of every parameter tensor in the network, per step.
        for name, param in resnet18.named_parameters():
            writer.add_histogram(name,
                                 param.clone().cpu().data.numpy(), n_iter)

        # needs tensorboard 0.4RC or later
        writer.add_pr_curve('xoxo', np.random.randint(2, size=100),
                            np.random.rand(100), n_iter)

# Build a 100-sample MNIST embedding for TensorBoard's projector view.
dataset = datasets.MNIST('mnist', train=False, download=True)
# NOTE(review): test_data/test_labels are deprecated aliases in recent
# torchvision (now .data/.targets) — confirm the installed version has them.
images = dataset.test_data[:100].float()
label = dataset.test_labels[:100]

# Flatten 28x28 images to 784-dim feature vectors; keep the originals
# (with a channel axis added) as thumbnail sprites for the projector.
features = images.view(100, 784)
writer.add_embedding(features, metadata=label, label_img=images.unsqueeze(1))
コード例 #2
0
# Build the run-specific output directory. Gradient-penalty runs get a
# "-gp" suffix on the model/mode folder so they never collide with plain runs.
_gp_suffix = '-gp' if GRADIENT_PENALTY else ''
OUTPUT_PATH = os.path.join(OUTPUT_PATH, '%s_%s%s' % (MODEL, MODE, _gp_suffix),
                           '%s/lrd=%.1e_lrg=%.1e/s%i/%i' %
                           ('aybat_adam', LEARNING_RATE_D,
                            LEARNING_RATE_G, SEED, int(time.time())))

if TENSORBOARD_FLAG:
    from tensorboardX import SummaryWriter
    writer = SummaryWriter(log_dir=os.path.join(OUTPUT_PATH, 'tensorboard'))
    # Keep the full run configuration alongside the logs for reproducibility.
    writer.add_text('config', json.dumps(vars(args), indent=2, sort_keys=True))


# CIFAR-10 pipeline: ToTensor gives [0, 1]; Normalize maps to [-1, 1] per channel.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, num_workers=1)

# Fixed: the original used the Python 2 print statement, which is a
# SyntaxError on Python 3 (the rest of this code is Python 3 style).
print('Init....')
# exist_ok=True replaces the exists()/makedirs() pair and avoids its race.
os.makedirs(os.path.join(OUTPUT_PATH, 'checkpoints'), exist_ok=True)
os.makedirs(os.path.join(OUTPUT_PATH, 'gen'), exist_ok=True)
コード例 #3
0
    # Normalize args.device to a list of GPU ids: accepts either a
    # comma-separated string ("0,1") or a single int.
    try:
        args.device = [int(item) for item in args.device.split(',')]
    except AttributeError:
        # .split missing -> args.device was already an int, not a string.
        args.device = [int(args.device)]
    args.modeldevice = args.device
    util.setup_runtime(seed=42,
                       cuda_dev_id=list(
                           np.unique(args.modeldevice + args.device)))
    print(args, flush=True)
    print()
    print(name, flush=True)
    time.sleep(5)  # short pause so the printed config can be checked/aborted

    # One TensorBoard run directory per experiment name; dump args as text.
    writer = SummaryWriter('./runs/%s' % name)
    writer.add_text(
        'args',
        " \n".join(['%s %s' % (arg, getattr(args, arg))
                    for arg in vars(args)]))

    # Setup model and train_loader
    model, train_loader = return_model_loader(args)
    print(len(train_loader.dataset))
    model.to('cuda:0')
    if torch.cuda.device_count() > 1:
        print("Let's use", len(args.modeldevice), "GPUs for the model")
        if len(args.modeldevice) == 1:
            print('single GPU model', flush=True)
        else:
            # Parallelize only the feature extractor across the model devices.
            model.features = nn.DataParallel(model.features,
                                             device_ids=list(
                                                 range(len(args.modeldevice))))
    # Setup optimizer
コード例 #4
0
ファイル: train_new.py プロジェクト: chengxiaoy/ULIS
def train_(model, train_dataloader, valid_dataloader, early_stopping, group_id,
           index, config):
    """Train *model* with per-epoch validation, TensorBoard logging and
    early stopping.

    Each epoch runs one pass over ``train_dataloader`` (optionally folding
    weights into an SWA average), evaluates on ``valid_dataloader``, logs
    loss and macro-F1 under ``group_{group_id}/cv_{index}``, and stops when
    ``early_stopping(metric, model)`` returns 2.
    """
    writer = SummaryWriter(
        logdir=os.path.join("board/", str(config.expriment_id)))
    criterion = get_criterion(config)
    optimizer = get_optimizer(config, model)
    if config.use_swa:
        # Wrap the base optimizer so it maintains an averaged weight copy.
        optimizer = torchcontrib.optim.SWA(optimizer)
    schedular = get_schedular(config, optimizer, len(train_dataloader))

    for epoch in range(config.EPOCHS):
        start_time = time.time()
        print("Epoch : {}".format(epoch))
        # print("learning_rate: {:0.9f}".format(schedular.get_lr()[0]))
        train_losses, valid_losses = [], []

        model.train()  # prep model for training
        # Epoch-wide accumulators for predictions/targets, kept on device.
        train_preds, train_true = torch.Tensor([]).to(
            config.device), torch.LongTensor([]).to(config.device)
        ii = 0
        for x, y in train_dataloader:
            x = x.to(config.device)
            y = y.to(config.device)

            optimizer.zero_grad()
            predictions = model(x)

            # Flatten so the criterion sees one row of class scores per target.
            predictions_ = predictions.reshape(-1, predictions.shape[-1])
            y_ = y.reshape(-1)

            loss = criterion(predictions_, y_)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            if config.schedular == 'cos' or config.schedular == 'cyc':
                # cosine/cyclic schedules advance per batch, not per epoch
                schedular.step()
            if config.use_swa:
                # after a 10-batch warm-up, fold every second batch's
                # weights into the SWA running average
                if ii >= 10 and ii % 2 == 0:
                    optimizer.update_swa()
                ii += 1
            # record training loss
            train_losses.append(loss.item())

            train_true = torch.cat([train_true, y_], 0)
            train_preds = torch.cat([train_preds, predictions_], 0)
        if config.use_swa:
            # Evaluate with the averaged weights (swapped back after the epoch).
            optimizer.swap_swa_sgd()

        model.eval()  # prep model for evaluation
        val_preds, val_true = torch.Tensor([]).to(
            config.device), torch.LongTensor([]).to(config.device)
        with torch.no_grad():
            for x, y in valid_dataloader:
                x = x.to(config.device)
                y = y.to(config.device)

                predictions = model(x)
                predictions_ = predictions.reshape(-1, predictions.shape[-1])
                y_ = y.reshape(-1)

                loss = criterion(predictions_, y_)
                valid_losses.append(loss.item())

                val_true = torch.cat([val_true, y_], 0)
                val_preds = torch.cat([val_preds, predictions_], 0)

        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        print("train_loss: {:0.6f}, valid_loss: {:0.6f}".format(
            train_loss, valid_loss))

        # Macro-F1 restricted to labels actually present in the targets.
        train_true = train_true.cpu().detach().numpy()
        train_preds = train_preds.cpu().detach().numpy().argmax(1)
        train_score = f1_score(train_true,
                               train_preds,
                               labels=np.unique(train_true),
                               average='macro')
        train_accurancy = np.sum(train_true == train_preds) / len(train_true)

        val_true = val_true.cpu().detach().numpy()
        val_preds = val_preds.cpu().detach().numpy().argmax(1)
        val_score = f1_score(val_true,
                             val_preds,
                             labels=np.unique(val_true),
                             average='macro')

        if config.schedular == 'reduce':
            # plateau scheduler keys off the validation score, once per epoch
            schedular.step(val_score)
        val_accurancy = np.sum(val_true == val_preds) / len(val_true)

        print("train_f1: {:0.6f}, valid_f1: {:0.6f}".format(
            train_score, val_score))
        # print("train_acc: {:0.6f}, valid_acc: {:0.6f}".format(train_accurancy, val_accurancy))

        writer.add_scalars('group_{}/cv_{}/loss'.format(group_id, index), {
            'train': train_loss,
            'val': valid_loss
        }, epoch)
        writer.add_scalars('group_{}/cv_{}/f1_score'.format(group_id, index), {
            'train': train_score,
            'val': val_score
        }, epoch)
        # writer.add_scalars('group_{}/cv_{}/acc'.format(group_id, index),
        #                    {'train': train_accurancy, 'val': val_accurancy},
        #                    epoch)
        # Early-stopping metric: maximize F1 or minimize validation loss.
        if config.early_stop_max:
            metric = val_score
        else:
            metric = valid_loss
        if early_stopping(metric, model) == 2:
            if config.use_swa and config.use_cbr:
                # Recompute batch-norm statistics under the SWA weights,
                # then persist the final model.
                optimizer.bn_update(train_dataloader, model, config.device)
                early_stopping.save_model(model)
            print("Early Stopping...")
            print("Best Val Score: {:0.6f}".format(early_stopping.best_score))
            writer.add_text("val_score", "valid_f1_score_{}".format(val_score),
                            index)
            break

        if config.use_swa:
            # Restore the raw (non-averaged) weights before the next epoch.
            optimizer.swap_swa_sgd()

        print("--- %s seconds ---" % (time.time() - start_time))
コード例 #5
0
# Minimal tensorboardX demo: two geometrically decaying scalars and their
# difference, logged as scalars, histograms, a scalar group and text.
import torch
from tensorboardX import SummaryWriter
import time

writer = SummaryWriter('train_log')
x = torch.FloatTensor([100])
y = torch.FloatTensor([500])

for epoch in range(100):
    # Both series decay by 1.5x per step, so the gap (loss) shrinks toward 0.
    x /= 1.5
    y /= 1.5
    loss = y - x
    print(loss)
    writer.add_histogram('zz/x', x, epoch)
    writer.add_histogram('zz/y', y, epoch)
    writer.add_scalar('data/x', x, epoch)
    writer.add_scalar('data/y', y, epoch)
    writer.add_scalar('data/loss', loss, epoch)
    writer.add_scalars('data/scalar_group', {'x': x,
                                             'y': y,
                                             'loss': loss}, epoch)
    writer.add_text('zz/text', 'zz: this is epoch ' + str(epoch), epoch)
    time.sleep(0.5)  # throttle so the run can be watched live in TensorBoard

# export scalar data to JSON for external processing
writer.export_scalars_to_json("./test.json")
writer.close()
コード例 #6
0
"noise_dim" : 5,
"optim_type":"adam"
}

# Unpack hyper-parameters from the params dict defined above.
epoches = params["epoches"]
batch_size = params["batch_size"]
lr = params["lr"]
hidden_dim = params["hidden_dim"]
noise_dim = params["noise_dim"]
optim_type = params["optim_type"]

# -----------------
#  log params
# -----------------
for k,v in params.items():
    writer.add_text(k,str(v))


# NOTE(review): this path mixes escaped (\\) and bare backslashes; it only
# works because Python leaves unrecognized escapes literal — a raw string
# r'...' would be the safe spelling. Confirm before changing.
filepath = 'C:\\Users\Snow\Desktop\大三下\模式识别与深度学习\Dlab\lab5\points.mat'

points_set = utils.Points(filepath)
data_loader = torch.utils.data.DataLoader(points_set, batch_size=batch_size, shuffle=True, num_workers=0)

# GAN on 2-D points: D maps a point to one score, G maps noise to a point.
discriminator = Discriminator(2,hidden_dim,1)
generator = Generator(noise_dim,hidden_dim,2)

discriminator.cuda()
generator.cuda()

if optim_type == "rmsprop":
    optimizer_D = optim.RMSprop(discriminator.parameters(), lr = lr)
コード例 #7
0
                    help='dimention of encoded vector')
parser.add_argument('--batchsize',
                    '-b',
                    type=int,
                    default=100,
                    help='learning minibatch size')
parser.add_argument('--test',
                    action='store_true',
                    help='Use tiny datasets for quick tests')
args = parser.parse_args()

batchsize = args.batchsize
n_epoch = args.epoch
n_latent = args.dimz  # dimensionality of the latent vector z

# Record the full argparse namespace in TensorBoard for reproducibility.
writer.add_text('config', str(args))

print('GPU: {}'.format(args.gpu))
print('# dim z: {}'.format(args.dimz))
print('# Minibatch-size: {}'.format(args.batchsize))
print('# epoch: {}'.format(args.epoch))
print('')

# Prepare dataset
print('load MNIST dataset')
mnist = data.load_mnist_data()
mnist['data'] = mnist['data'].astype(np.float32)
mnist['data'] /= 255  # scale pixel values to [0, 1]
mnist['target'] = mnist['target'].astype(np.int32)

if args.test:
コード例 #8
0
ファイル: train.py プロジェクト: ameypatil10/CheXpert-CXR
def train(resume_path=None, jigsaw_path=None):
    """Train the CheXpert discriminator with per-epoch validation and
    checkpointing.

    Parameters
    ----------
    resume_path : str, optional
        NOTE(review): accepted but unused in this excerpt — confirm intent.
    jigsaw_path : str, optional
        Checkpoint of a jigsaw-pretrained network whose feature extractor
        is copied into the discriminator before training.
    """

    writer = SummaryWriter('../runs/'+hparams.exp_name)

    # Log every hyper-parameter value as a TensorBoard text entry.
    for k in hparams.__dict__.keys():
        writer.add_text(str(k), str(hparams.__dict__[k]))

    train_dataset = ChestData(data_csv=hparams.train_csv, data_dir=hparams.train_dir, augment=hparams.augment,
                        transform=transforms.Compose([
                            transforms.Resize(hparams.image_shape),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5027, 0.5027, 0.5027), (0.2915, 0.2915, 0.2915))
                        ]))

    validation_dataset = ChestData(data_csv=hparams.valid_csv, data_dir=hparams.valid_dir,
                        transform=transforms.Compose([
                            transforms.Resize(hparams.image_shape),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5027, 0.5027, 0.5027), (0.2915, 0.2915, 0.2915))
                        ]))

    # train_sampler = WeightedRandomSampler()

    train_loader = DataLoader(train_dataset, batch_size=hparams.batch_size,
                            shuffle=True, num_workers=2)

    validation_loader = DataLoader(validation_dataset, batch_size=hparams.batch_size,
                            shuffle=True, num_workers=2)

    print('loaded train data of length : {}'.format(len(train_dataset)))

    adversarial_loss = torch.nn.BCELoss().to(hparams.gpu_device)
    discriminator = Discriminator().to(hparams.gpu_device)

    if hparams.cuda:
        discriminator = nn.DataParallel(discriminator, device_ids=hparams.device_ids)

    # Count trainable parameters for the log.
    params_count = 0
    for param in discriminator.parameters():
        params_count += np.prod(param.size())
    print('Model has {0} trainable parameters'.format(params_count))

    if not hparams.pretrained:
#         discriminator.apply(weights_init_normal)
        pass
    if jigsaw_path:
        # Transplant the jigsaw-pretrained feature extractor into the model.
        jigsaw = Jigsaw().to(hparams.gpu_device)
        if hparams.cuda:
            jigsaw = nn.DataParallel(jigsaw, device_ids=hparams.device_ids)
        checkpoints = torch.load(jigsaw_path, map_location=hparams.gpu_device)
        jigsaw.load_state_dict(checkpoints['discriminator_state_dict'])
        discriminator.module.model.features = jigsaw.module.feature.features
        print('loaded pretrained feature extractor from {} ..'.format(jigsaw_path))

    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=hparams.learning_rate, betas=(0.9, 0.999))

    # Cut the LR by 10x after one epoch without validation-loss improvement.
    scheduler_D = ReduceLROnPlateau(optimizer_D, mode='min', factor=0.1, patience=1, verbose=True, cooldown=0)

    Tensor = torch.cuda.FloatTensor if hparams.cuda else torch.FloatTensor

    def validation(discriminator, send_stats=False, epoch=0):
        # Evaluate on the full validation set; returns (metrics tuple, loss).
        print('Validating model on {0} examples. '.format(len(validation_dataset)))
        discriminator_ = discriminator.eval()

        with torch.no_grad():
            pred_logits_list = []
            labels_list = []

            for (img, labels, imgs_names) in tqdm(validation_loader):
                img = Variable(img.float(), requires_grad=False)
                labels = Variable(labels.float(), requires_grad=False)

                img_ = img.to(hparams.gpu_device)
                labels = labels.to(hparams.gpu_device)

                pred_logits = discriminator_(img_)

                pred_logits_list.append(pred_logits)
                labels_list.append(labels)

            # Score the whole validation set at once rather than per batch.
            pred_logits = torch.cat(pred_logits_list, dim=0)
            labels = torch.cat(labels_list, dim=0)

            val_loss = adversarial_loss(pred_logits, labels)

        return accuracy_metrics(labels.long(), pred_logits), val_loss

    print('Starting training.. (log saved in:{})'.format(hparams.exp_name))
    start_time = time.time()
    best_valid_auc = 0

    # print(model)
    for epoch in range(hparams.num_epochs):
        for batch, (imgs, labels, imgs_name) in enumerate(tqdm(train_loader)):

            imgs = Variable(imgs.float(), requires_grad=False)
            labels = Variable(labels.float(), requires_grad=False)

            imgs_ = imgs.to(hparams.gpu_device)
            labels = labels.to(hparams.gpu_device)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            pred_logits = discriminator(imgs_)

            d_loss = adversarial_loss(pred_logits, labels)

            d_loss.backward()
            optimizer_D.step()

            # Global step = batches completed across all epochs so far.
            writer.add_scalar('d_loss', d_loss.item(), global_step=batch+epoch*len(train_loader))

            # Binarize logits at the configured decision threshold.
            pred_labels = (pred_logits >= hparams.thresh)
            pred_labels = pred_labels.float()

            # if batch % hparams.print_interval == 0:
            #     auc, f1, acc, _, _ = accuracy_metrics(pred_labels, labels.long(), pred_logits)
            #     print('[Epoch - {0:.1f}, batch - {1:.3f}, d_loss - {2:.6f}, acc - {3:.4f}, f1 - {4:.5f}, auc - {5:.4f}]'.\
            #     format(1.0*epoch, 100.0*batch/len(train_loader), d_loss.item(), acc['avg'], f1[hparams.avg_mode], auc[hparams.avg_mode]))
        (val_auc, val_f1, val_acc, val_conf_mat, best_thresh), val_loss = validation(discriminator, epoch=epoch)

        # Per-class confusion matrices and metrics, then micro/macro summaries.
        for lbl in range(hparams.num_classes):
            fig = plot_cf(val_conf_mat[lbl])
            writer.add_figure('val_conf_{}'.format(hparams.id_to_class[lbl]), fig, global_step=epoch)
            plt.close(fig)
            writer.add_scalar('val_f1_{}'.format(hparams.id_to_class[lbl]), val_f1[lbl], global_step=epoch)
            writer.add_scalar('val_auc_{}'.format(hparams.id_to_class[lbl]), val_auc[lbl], global_step=epoch)
            writer.add_scalar('val_acc_{}'.format(hparams.id_to_class[lbl]), val_acc[lbl], global_step=epoch)
        writer.add_scalar('val_f1_{}'.format('micro'), val_f1['micro'], global_step=epoch)
        writer.add_scalar('val_auc_{}'.format('micro'), val_auc['micro'], global_step=epoch)
        writer.add_scalar('val_f1_{}'.format('macro'), val_f1['macro'], global_step=epoch)
        writer.add_scalar('val_auc_{}'.format('macro'), val_auc['macro'], global_step=epoch)
        writer.add_scalar('val_loss', val_loss, global_step=epoch)
        writer.add_scalar('val_f1', val_f1[hparams.avg_mode], global_step=epoch)
        writer.add_scalar('val_auc', val_auc[hparams.avg_mode], global_step=epoch)
        writer.add_scalar('val_acc', val_acc['avg'], global_step=epoch)
        scheduler_D.step(val_loss)
        writer.add_scalar('learning_rate', optimizer_D.param_groups[0]['lr'], global_step=epoch)

        # Checkpoint every epoch; additionally track the best-AUC model below.
        torch.save({
            'epoch': epoch,
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
            }, hparams.model+'.'+str(epoch))
        if best_valid_auc <= val_auc[hparams.avg_mode]:
            best_valid_auc = val_auc[hparams.avg_mode]
            for lbl in range(hparams.num_classes):
                fig = plot_cf(val_conf_mat[lbl])
                writer.add_figure('best_val_conf_{}'.format(hparams.id_to_class[lbl]), fig, global_step=epoch)
                plt.close(fig)
            torch.save({
                'epoch': epoch,
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                }, hparams.model+'.best')
            print('best model on validation set saved.')
        print('[Epoch - {0:.1f} ---> val_auc - {1:.4f}, current_lr - {2:.6f}, val_loss - {3:.4f}, best_val_auc - {4:.4f}, val_acc - {5:.4f}, val_f1 - {6:.4f}] - time - {7:.1f}'\
            .format(1.0*epoch, val_auc[hparams.avg_mode], optimizer_D.param_groups[0]['lr'], val_loss, best_valid_auc, val_acc['avg'], val_f1[hparams.avg_mode], time.time()-start_time))
        start_time = time.time()
コード例 #9
0
def train(opt):
    """
    Performs the training loop for the detection model.

    :param opt: namespace combining all parameters, including the network
        (opt.model), criterion, optimizer and logging/output settings
    """
    if torch.cuda.is_available() and opt.useCuda:
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)

    training_set, training_loader, eval_set, eval_loader = loadTrainEvalSet(
        opt)

    # log stuff: start each run with a fresh, empty log directory
    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)
    writer = SummaryWriter(opt.log_path)
    # add_graph traces the model on a dummy batch; when CUDA is used the
    # model is traced on CPU first and then moved back to the GPU.
    if torch.cuda.is_available() and opt.useCuda:
        writer.add_graph(
            opt.model.cpu(),
            torch.rand(opt.batch_size, 3, opt.image_size, opt.image_size))
        opt.model.cuda()
    else:
        writer.add_graph(
            opt.model,
            torch.rand(opt.batch_size, 3, opt.image_size, opt.image_size))

    # write the hyperparams lr, batchsize and imagesize in the tensorboard file
    writer.add_text(
        'Hyperparams', 'lr: {}, \nbatchsize: {}, \nimg_size:{}'.format(
            opt.learning_rate, opt.batch_size, opt.image_size))

    # loss and optimize
    if opt.optimizer is None:
        opt.optimizer = torch.optim.Adam(opt.model.parameters(),
                                         lr=opt.learning_rate,
                                         betas=(opt.momentum, 0.999),
                                         weight_decay=opt.decay)

    epoch_len = len(training_loader)
    for epoch in range(opt.num_epoches):
        training_set.dataset.is_training = True
        print('num epoch: {:4d}'.format(epoch))
        opt.model.train()
        for img_nr, (gt, img) in enumerate(training_loader):
            if torch.cuda.is_available() and opt.useCuda:
                img = Variable(img.cuda(), requires_grad=True)
            else:
                img = Variable(img, requires_grad=True)
            opt.optimizer.zero_grad()
            logits = opt.model(img)
            loss, loss_coord, loss_conf = opt.criterion(logits, gt)
            # global step = epoch * batches-per-epoch + batch index
            writeLossToSummary(writer, 'Train', loss.item(), loss_coord.item(),
                               loss_conf.item(), epoch * epoch_len + img_nr)
            loss.backward()
            opt.optimizer.step()

        # eval stuff
        opt.model.eval()
        eval_set.dataset.is_training = False
        loss_ls = []
        loss_coord_ls = []
        loss_conf_ls = []
        all_ap = []
        all_ap1 = []
        for te_iter, te_batch in enumerate(eval_loader):
            te_label, te_image = te_batch
            num_sample = len(te_label)
            if torch.cuda.is_available() and opt.useCuda:
                te_image = te_image.cuda()
            with torch.no_grad():
                te_logits = opt.model(te_image)
                batch_loss, batch_loss_coord, batch_loss_conf = opt.criterion(
                    te_logits, te_label)
                # per-image average precision at IoU thresholds 0.5 and 0.8
                for i in range(num_sample):
                    ap = get_ap(te_logits[i],
                                filter_non_zero_gt_without_id(te_label[i]),
                                opt.image_size, opt.image_size,
                                opt.model.anchors, .5)
                    ap1 = get_ap(te_logits[i],
                                 filter_non_zero_gt_without_id(te_label[i]),
                                 opt.image_size, opt.image_size,
                                 opt.model.anchors, .8)
                    if not np.isnan(ap):
                        all_ap.append(ap)
                    if not np.isnan(ap1):
                        all_ap1.append(ap1)
            # weight each batch loss by its sample count for a correct mean
            loss_ls.append(batch_loss * num_sample)
            loss_coord_ls.append(batch_loss_coord * num_sample)
            loss_conf_ls.append(batch_loss_conf * num_sample)
        te_loss = sum(loss_ls) / eval_set.__len__()
        te_coord_loss = sum(loss_coord_ls) / eval_set.__len__()
        te_conf_loss = sum(loss_conf_ls) / eval_set.__len__()
        writer.add_scalar('Val/AP0.5', np.mean(np.array(all_ap)),
                          epoch * epoch_len)
        writer.add_scalar('Val/AP0.8', np.mean(np.array(all_ap1)),
                          epoch * epoch_len)
        writeLossToSummary(writer, 'Val', te_loss.item(), te_coord_loss.item(),
                           te_conf_loss.item(), epoch * epoch_len)

        # snapshot model + optimizer state every epoch
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': opt.model.state_dict(),
                'optimizer_state_dict': opt.optimizer.state_dict()
            }, opt.log_path + '/snapshot{:04d}.tar'.format(epoch))
    writer.close()
コード例 #10
0
    'batch_size': config.batch_size,
    'log path': str(log_path),
    'loss': str(config.loss),
    'score fnc': str(opt.score_fc),
    'trans_n_layers': config.trans_n_layers,
    'trans_n_head': config.trans_n_head,
    'trans_d_k': config.trans_d_k,
    'trans_d_v': config.trans_d_v,
    'trans_d_model': config.trans_d_model,
    'trans_d_inner': config.trans_d_inner,
    'trans_dropout': config.trans_dropout,
    '2channel': config.is_two_channel,
}
# Mirror the hyper-parameter dict to lera and to TensorBoard text entries.
lera.log_hyperparams(global_par_dict)
for item in list(global_par_dict.keys()):
    writer.add_text(item, str(global_par_dict[item]))


def train(epoch):
    """Run one training epoch (definition truncated in this excerpt)."""
    global e, updates, total_loss, start_time, report_total, report_correct, total_loss_sgm, total_loss_ss
    e = epoch
    model.train()
    SDR_SUM = np.array([])
    SDRi_SUM = np.array([])

    if updates <= config.warmup:  # during warm-up, leave the LR unchanged
        pass
    elif config.schedule and scheduler.get_lr()[0] > 4e-5:
        # Past warm-up: decay the learning rate until it reaches the 4e-5 floor.
        scheduler.step()
        print(
            ("Decaying learning rate to %g" % scheduler.get_lr()[0], updates))
コード例 #11
0
def set_up_logging(parser, experiment_name, output_folder, quiet, args_dict, debug, **kwargs):
    """
    Set up a logger for the experiment

    Parameters
    ----------
    parser : parser
        The argument parser
    experiment_name : string
        Name of the experiment. If not specified, it is read from stdin.
    output_folder : string
        Path to where all experiment logs are stored.
    quiet : bool
        Specify whether to print log to console or only to text file
    args_dict : dict
        Contains the entire argument dictionary specified via command line.
    debug : bool
        Specify the logging level (DEBUG when True, INFO otherwise)

    Returns
    -------
    log_folder : string
        The final logging folder tree
    writer : tensorboardX.writer.SummaryWriter
        The tensorboard writer object. Used to log values on file for the tensorboard visualization.
    """
    LOG_FILE = 'logs.txt'

    # Experiment name override
    if experiment_name is None:
        experiment_name = input("Experiment name:")

    # Recover dataset name
    dataset = os.path.basename(os.path.normpath(kwargs['dataset_folder']))

    # We extract the TRAIN parameter names (model_name, lr, ...) from the
    # parser directly. This is somewhat risky because we access _private_
    # variables of parser classes, but within our context it can be regarded
    # as safe. Should that break, a quick fix is a manually maintained list
    # such as: train_param_list = ['model_name', 'lr', ...] (boring!).
    # Resources:
    # https://stackoverflow.com/questions/31519997/is-it-possible-to-only-parse-one-argument-groups-parameters-with-argparse

    # Fetch all non-default parameters
    non_default_parameters = []

    for group in parser._action_groups[2:]:
        if group.title not in ['GENERAL', 'DATA']:
            for action in group._group_actions:
                if (kwargs[action.dest] is not None) and (kwargs[action.dest] != action.default) and action.dest != 'load_model':
                    non_default_parameters.append(str(action.dest) + "=" + str(kwargs[action.dest]))

    # Build up final logging folder tree with the non-default training parameters
    log_folder = os.path.join(*[output_folder, experiment_name, dataset, *non_default_parameters,
                                '{}'.format(time.strftime('%d-%m-%y-%Hh-%Mm-%Ss'))])
    # exist_ok=True replaces the exists()/makedirs() pair and avoids its race.
    os.makedirs(log_folder, exist_ok=True)

    # Setup logging
    root = logging.getLogger()
    log_level = logging.DEBUG if debug else logging.INFO
    root.setLevel(log_level)
    # Renamed from `format` to avoid shadowing the builtin of the same name.
    log_format = "[%(asctime)s] [%(levelname)8s] --- %(message)s (%(filename)s:%(lineno)s)"
    date_format = '%Y-%m-%d %H:%M:%S'

    if os.isatty(2):
        # stderr is a terminal: colorize console output per log level.
        cformat = '%(log_color)s' + log_format
        formatter = colorlog.ColoredFormatter(cformat, date_format,
                                              log_colors={
                                                  'DEBUG': 'cyan',
                                                  'INFO': 'white',
                                                  'WARNING': 'yellow',
                                                  'ERROR': 'red',
                                                  'CRITICAL': 'red,bg_white',
                                              })
    else:
        formatter = logging.Formatter(log_format, date_format)

    if not quiet:
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        root.addHandler(ch)

    # Always log to file as well, without colors.
    fh = logging.FileHandler(os.path.join(log_folder, LOG_FILE))
    fh.setFormatter(logging.Formatter(log_format, date_format))
    root.addHandler(fh)

    logging.info('Setup logging. Log file: {}'.format(os.path.join(log_folder, LOG_FILE)))

    # Save args to logs_folder
    logging.info('Arguments saved to: {}'.format(os.path.join(log_folder, 'args.txt')))
    with open(os.path.join(log_folder, 'args.txt'), 'w') as f:
        f.write(json.dumps(args_dict))

    # Define Tensorboard SummaryWriter
    logging.info('Initialize Tensorboard SummaryWriter')

    # Add all parameters to Tensorboard
    writer = SummaryWriter(log_dir=log_folder)
    writer.add_text('Args', json.dumps(args_dict))

    return log_folder, writer
コード例 #12
0
                                    num_workers=args.workers,
                                    pin_memory=True)

        if args.ckpt:
            # Resuming from a checkpoint: the one-off sample logging below
            # was already done on the first run.
            pass
        else:
            # save graph and clips_order samples
            for data in train_dataloader:
                #tuple_clips, tuple_orders, tuple_clips_random, tuple_orders_random,idx = data
                tuple_clips, tuple_orders, idx = data
                for i in range(args.tl):
                    writer.add_video('train/tuple_clips',
                                     tuple_clips[:, i, :, :, :, :],
                                     i,
                                     fps=8)
                    writer.add_text('train/tuple_orders',
                                    str(tuple_orders[:, i].tolist()), i)
                tuple_clips = tuple_clips.to(device)
                #writer.add_graph(tcg, tuple_clips)
                break  # only the first batch is needed for the samples
            # save init params at step 0
            for name, param in tcg.named_parameters():
                writer.add_histogram('params/{}'.format(name), param, 0)

        n_data = train_dataset.__len__()

        ### loss function, optimizer and scheduler ###
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(tcg.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.wd)
コード例 #13
0
def main(opts):
    """Train the end-to-end 3-branch (CIHP / PASCAL / ATR) DeepLab-Xception
    human-parsing network and periodically evaluate on the PASCAL val split.

    The three datasets are interleaved by a custom sampler; each mini-batch
    comes from a single dataset, identified by ``sample_batched['pascal']``.
    For the first ``nEpochs`` epochs the uniform sampler is used, then for
    another ``0.5 * nEpochs`` epochs a PASCAL-balanced sampler takes over.

    :param opts: parsed command-line options; fields used include ``batch``,
        ``testInterval``, ``lr``, ``step``, ``numworker``, ``pretrainedModel``,
        ``loadmodel``, ``hidden_layers`` and ``epochs``.

    NOTE(review): relies on module-level globals ``gpu_id`` and
    ``resume_epoch`` and on helpers ``get_graphs`` / ``val_pascal`` defined
    elsewhere in this file — confirm they are in scope before reuse.
    """
    # Set parameters
    p = OrderedDict()  # Parameters to include in report
    p["trainBatch"] = opts.batch  # Training batch size
    testBatch = 1  # Testing batch size
    useTest = True  # See evolution of the test set when training
    nTestInterval = opts.testInterval  # Run on test set every nTestInterval epochs
    snapshot = 1  # Store a model every snapshot epochs
    p["nAveGrad"] = 1  # Average the gradient of several iterations
    p["lr"] = opts.lr  # Learning rate
    p["wd"] = 5e-4  # Weight decay
    p["momentum"] = 0.9  # Momentum
    p["epoch_size"] = opts.step  # How many epochs to change learning rate
    p["num_workers"] = opts.numworker
    model_path = opts.pretrainedModel
    backbone = "xception"  # Use xception or resnet as feature extractor
    nEpochs = opts.epochs

    # Pick the next free run directory: scan existing run/run_<id> folders and
    # use max(existing id) + 1 (max_id stays 0 when no runs exist yet).
    max_id = 0
    save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
    # NOTE(review): exp_name is computed but never used below.
    exp_name = os.path.dirname(os.path.abspath(__file__)).split("/")[-1]
    runs = glob.glob(os.path.join(save_dir_root, "run", "run_*"))
    for r in runs:
        run_id = int(r.split("_")[-1])
        if run_id >= max_id:
            max_id = run_id + 1
    # run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
    save_dir = os.path.join(save_dir_root, "run", "run_" + str(max_id))

    # Network definition
    if backbone == "xception":
        net_ = deeplab_xception_universal.deeplab_xception_end2end_3d(
            n_classes=20,
            os=16,
            hidden_layers=opts.hidden_layers,
            source_classes=7,
            middle_classes=18,
        )
    elif backbone == "resnet":
        # net_ = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
        raise NotImplementedError
    else:
        raise NotImplementedError

    # Timestamped name used for checkpoint files saved below.
    modelName = (
        "deeplabv3plus-" + backbone + "-voc" + datetime.now().strftime("%b%d_%H-%M-%S")
    )
    criterion = ut.cross_entropy2d

    if gpu_id >= 0:
        # torch.cuda.set_device(device=gpu_id)
        net_.cuda()

    # net load weights
    if not model_path == "":
        x = torch.load(model_path)
        net_.load_state_dict_new(x)
        print("load pretrainedModel.")
    else:
        print("no pretrainedModel.")

    # Optionally warm-start the source branch from a previously trained model.
    if not opts.loadmodel == "":
        x = torch.load(opts.loadmodel)
        net_.load_source_model(x)
        print("load model:", opts.loadmodel)
    else:
        print("no trained model load !!!!!!!!")

    # Tensorboard writer; logs go next to the model checkpoints.
    log_dir = os.path.join(
        save_dir,
        "models",
        datetime.now().strftime("%b%d_%H-%M-%S") + "_" + socket.gethostname(),
    )
    writer = SummaryWriter(log_dir=log_dir)
    writer.add_text("load model", opts.loadmodel, 1)
    writer.add_text("setting", sys.argv[0], 1)

    # Use the following optimizer
    optimizer = optim.SGD(
        net_.parameters(), lr=p["lr"], momentum=p["momentum"], weight_decay=p["wd"]
    )

    # Data augmentation: random-sized crop + xception normalization for train,
    # plain (and horizontally flipped) normalization for validation.
    composed_transforms_tr = transforms.Compose(
        [tr.RandomSized_new(512), tr.Normalize_xception_tf(), tr.ToTensor_()]
    )

    composed_transforms_ts = transforms.Compose(
        [tr.Normalize_xception_tf(), tr.ToTensor_()]
    )

    composed_transforms_ts_flip = transforms.Compose(
        [tr.HorizontalFlip(), tr.Normalize_xception_tf(), tr.ToTensor_()]
    )

    all_train = cihp_pascal_atr.VOCSegmentation(
        split="train", transform=composed_transforms_tr, flip=True
    )
    voc_val = pascal.VOCSegmentation(split="val", transform=composed_transforms_ts)
    voc_val_flip = pascal.VOCSegmentation(
        split="val", transform=composed_transforms_ts_flip
    )

    num_cihp, num_pascal, num_atr = all_train.get_class_num()
    # Uniform sampler over the three datasets; shuffle must stay False because
    # the sampler controls ordering.
    ss = sam.Sampler_uni(num_cihp, num_pascal, num_atr, opts.batch)
    # balance datasets based pascal
    ss_balanced = sam.Sampler_uni(
        num_cihp, num_pascal, num_atr, opts.batch, balance_id=1
    )

    trainloader = DataLoader(
        all_train,
        batch_size=p["trainBatch"],
        shuffle=False,
        num_workers=p["num_workers"],
        sampler=ss,
        drop_last=True,
    )
    trainloader_balanced = DataLoader(
        all_train,
        batch_size=p["trainBatch"],
        shuffle=False,
        num_workers=p["num_workers"],
        sampler=ss_balanced,
        drop_last=True,
    )
    testloader = DataLoader(
        voc_val, batch_size=testBatch, shuffle=False, num_workers=p["num_workers"]
    )
    testloader_flip = DataLoader(
        voc_val_flip, batch_size=testBatch, shuffle=False, num_workers=p["num_workers"]
    )

    num_img_tr = len(trainloader)
    num_img_balanced = len(trainloader_balanced)
    num_img_ts = len(testloader)
    running_loss_tr = 0.0
    running_loss_tr_atr = 0.0
    running_loss_ts = 0.0
    aveGrad = 0
    global_step = 0
    print("Training Network")
    net = torch.nn.DataParallel(net_)

    id_list = torch.LongTensor(range(opts.batch))
    pascal_iter = int(num_img_tr // opts.batch)

    # Get graphs
    train_graph, test_graph = get_graphs(opts)
    adj1, adj2, adj3, adj4, adj5, adj6 = train_graph
    adj1_test, adj2_test, adj3_test, adj4_test, adj5_test, adj6_test = test_graph

    # Main Training and Testing Loop
    for epoch in range(resume_epoch, int(1.5 * nEpochs)):
        start_time = timeit.default_timer()

        # Poly learning-rate policy: every epoch_size epochs the optimizer is
        # re-created with a decayed lr. The second schedule restarts the decay
        # for the balanced-sampler phase after nEpochs.
        # NOTE(review): at epoch == nEpochs exactly, neither branch fires
        # (first needs epoch < nEpochs, second needs epoch > nEpochs).
        if epoch % p["epoch_size"] == p["epoch_size"] - 1 and epoch < nEpochs:
            lr_ = ut.lr_poly(p["lr"], epoch, nEpochs, 0.9)
            optimizer = optim.SGD(
                net_.parameters(), lr=lr_, momentum=p["momentum"], weight_decay=p["wd"]
            )
            print("(poly lr policy) learning rate: ", lr_)
            writer.add_scalar("data/lr_", lr_, epoch)
        elif epoch % p["epoch_size"] == p["epoch_size"] - 1 and epoch > nEpochs:
            lr_ = ut.lr_poly(p["lr"], epoch - nEpochs, int(0.5 * nEpochs), 0.9)
            optimizer = optim.SGD(
                net_.parameters(), lr=lr_, momentum=p["momentum"], weight_decay=p["wd"]
            )
            print("(poly lr policy) learning rate: ", lr_)
            writer.add_scalar("data/lr_", lr_, epoch)

        net_.train()
        if epoch < nEpochs:
            # Phase 1: uniform sampling over the three datasets.
            for ii, sample_batched in enumerate(trainloader):
                inputs, labels = sample_batched["image"], sample_batched["label"]
                # Dataset id of this batch: 0 = CIHP (target), 1 = PASCAL
                # (source), 2 = ATR (middle). Assumes the whole batch comes
                # from a single dataset -- guaranteed by the sampler.
                dataset_lbl = sample_batched["pascal"][0].item()
                # Forward-Backward of the mini-batch
                inputs, labels = Variable(inputs, requires_grad=True), Variable(labels)
                global_step += 1

                if gpu_id >= 0:
                    inputs, labels = inputs.cuda(), labels.cuda()

                if dataset_lbl == 0:
                    # 0 is cihp -- target
                    _, outputs, _ = net.forward(
                        None,
                        input_target=inputs,
                        input_middle=None,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                elif dataset_lbl == 1:
                    # pascal is source
                    outputs, _, _ = net.forward(
                        inputs,
                        input_target=None,
                        input_middle=None,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                else:
                    # atr
                    _, _, outputs = net.forward(
                        None,
                        input_target=None,
                        input_middle=inputs,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                # print(sample_batched['pascal'])
                # print(outputs.size(),)
                # print(labels)
                loss = criterion(outputs, labels, batch_average=True)
                running_loss_tr += loss.item()

                # Print stuff
                # End-of-epoch statistics (last iteration of the loader).
                if ii % num_img_tr == (num_img_tr - 1):
                    running_loss_tr = running_loss_tr / num_img_tr
                    writer.add_scalar("data/total_loss_epoch", running_loss_tr, epoch)
                    # NOTE(review): '%5d' is fed `epoch`, not an image count --
                    # the "numImages" field is misleading.
                    print("[Epoch: %d, numImages: %5d]" % (epoch, epoch))
                    print("Loss: %f" % running_loss_tr)
                    running_loss_tr = 0
                    stop_time = timeit.default_timer()
                    print("Execution time: " + str(stop_time - start_time) + "\n")

                # Backward the averaged gradient
                loss /= p["nAveGrad"]
                loss.backward()
                aveGrad += 1

                # Update the weights once in p['nAveGrad'] forward passes
                if aveGrad % p["nAveGrad"] == 0:
                    writer.add_scalar("data/total_loss_iter", loss.item(), global_step)
                    if dataset_lbl == 0:
                        writer.add_scalar(
                            "data/total_loss_iter_cihp", loss.item(), global_step
                        )
                    if dataset_lbl == 1:
                        writer.add_scalar(
                            "data/total_loss_iter_pascal", loss.item(), global_step
                        )
                    if dataset_lbl == 2:
                        writer.add_scalar(
                            "data/total_loss_iter_atr", loss.item(), global_step
                        )
                    optimizer.step()
                    optimizer.zero_grad()
                    # optimizer_gcn.step()
                    # optimizer_gcn.zero_grad()
                    aveGrad = 0

                # Show 10 * 3 images results each epoch
                # NOTE(review): raises ZeroDivisionError when num_img_tr < 10.
                if ii % (num_img_tr // 10) == 0:
                    grid_image = make_grid(
                        inputs[:3].clone().cpu().data, 3, normalize=True
                    )
                    writer.add_image("Image", grid_image, global_step)
                    grid_image = make_grid(
                        ut.decode_seg_map_sequence(
                            torch.max(outputs[:3], 1)[1].detach().cpu().numpy()
                        ),
                        3,
                        normalize=False,
                        range=(0, 255),
                    )
                    writer.add_image("Predicted label", grid_image, global_step)
                    grid_image = make_grid(
                        ut.decode_seg_map_sequence(
                            torch.squeeze(labels[:3], 1).detach().cpu().numpy()
                        ),
                        3,
                        normalize=False,
                        range=(0, 255),
                    )
                    writer.add_image("Groundtruth label", grid_image, global_step)

                print("loss is ", loss.cpu().item(), flush=True)
        else:
            # Balanced the number of datasets
            # Phase 2 (epoch >= nEpochs): identical loop body, but batches come
            # from the PASCAL-balanced sampler instead of the uniform one.
            for ii, sample_batched in enumerate(trainloader_balanced):
                inputs, labels = sample_batched["image"], sample_batched["label"]
                dataset_lbl = sample_batched["pascal"][0].item()
                # Forward-Backward of the mini-batch
                inputs, labels = Variable(inputs, requires_grad=True), Variable(labels)
                global_step += 1

                if gpu_id >= 0:
                    inputs, labels = inputs.cuda(), labels.cuda()

                if dataset_lbl == 0:
                    # 0 is cihp -- target
                    _, outputs, _ = net.forward(
                        None,
                        input_target=inputs,
                        input_middle=None,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                elif dataset_lbl == 1:
                    # pascal is source
                    outputs, _, _ = net.forward(
                        inputs,
                        input_target=None,
                        input_middle=None,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                else:
                    # atr
                    _, _, outputs = net.forward(
                        None,
                        input_target=None,
                        input_middle=inputs,
                        adj1_target=adj1,
                        adj2_source=adj2,
                        adj3_transfer_s2t=adj3,
                        adj3_transfer_t2s=adj3.transpose(2, 3),
                        adj4_middle=adj4,
                        adj5_transfer_s2m=adj5.transpose(2, 3),
                        adj6_transfer_t2m=adj6.transpose(2, 3),
                        adj5_transfer_m2s=adj5,
                        adj6_transfer_m2t=adj6,
                    )
                # print(sample_batched['pascal'])
                # print(outputs.size(),)
                # print(labels)
                loss = criterion(outputs, labels, batch_average=True)
                running_loss_tr += loss.item()

                # Print stuff
                if ii % num_img_balanced == (num_img_balanced - 1):
                    running_loss_tr = running_loss_tr / num_img_balanced
                    writer.add_scalar("data/total_loss_epoch", running_loss_tr, epoch)
                    # NOTE(review): same misleading "numImages" field as above.
                    print("[Epoch: %d, numImages: %5d]" % (epoch, epoch))
                    print("Loss: %f" % running_loss_tr)
                    running_loss_tr = 0
                    stop_time = timeit.default_timer()
                    print("Execution time: " + str(stop_time - start_time) + "\n")

                # Backward the averaged gradient
                loss /= p["nAveGrad"]
                loss.backward()
                aveGrad += 1

                # Update the weights once in p['nAveGrad'] forward passes
                if aveGrad % p["nAveGrad"] == 0:
                    writer.add_scalar("data/total_loss_iter", loss.item(), global_step)
                    if dataset_lbl == 0:
                        writer.add_scalar(
                            "data/total_loss_iter_cihp", loss.item(), global_step
                        )
                    if dataset_lbl == 1:
                        writer.add_scalar(
                            "data/total_loss_iter_pascal", loss.item(), global_step
                        )
                    if dataset_lbl == 2:
                        writer.add_scalar(
                            "data/total_loss_iter_atr", loss.item(), global_step
                        )
                    optimizer.step()
                    optimizer.zero_grad()

                    aveGrad = 0

                # Show 10 * 3 images results each epoch
                # NOTE(review): raises ZeroDivisionError when num_img_balanced < 10.
                if ii % (num_img_balanced // 10) == 0:
                    grid_image = make_grid(
                        inputs[:3].clone().cpu().data, 3, normalize=True
                    )
                    writer.add_image("Image", grid_image, global_step)
                    grid_image = make_grid(
                        ut.decode_seg_map_sequence(
                            torch.max(outputs[:3], 1)[1].detach().cpu().numpy()
                        ),
                        3,
                        normalize=False,
                        range=(0, 255),
                    )
                    writer.add_image("Predicted label", grid_image, global_step)
                    grid_image = make_grid(
                        ut.decode_seg_map_sequence(
                            torch.squeeze(labels[:3], 1).detach().cpu().numpy()
                        ),
                        3,
                        normalize=False,
                        range=(0, 255),
                    )
                    writer.add_image("Groundtruth label", grid_image, global_step)

                print("loss is ", loss.cpu().item(), flush=True)

        # Save the model
        if (epoch % snapshot) == snapshot - 1:
            torch.save(
                net_.state_dict(),
                os.path.join(
                    save_dir, "models", modelName + "_epoch-" + str(epoch) + ".pth"
                ),
            )
            print(
                "Save model at {}\n".format(
                    os.path.join(
                        save_dir, "models", modelName + "_epoch-" + str(epoch) + ".pth"
                    )
                )
            )

        # One testing epoch
        if useTest and epoch % nTestInterval == (nTestInterval - 1):
            val_pascal(
                net_=net_,
                testloader=testloader,
                testloader_flip=testloader_flip,
                test_graph=test_graph,
                criterion=criterion,
                epoch=epoch,
                writer=writer,
            )
コード例 #14
0
def main(args):
    """Entry point for video-classification training / testing.

    Validates the dataset directory layout, prepares the output folder and
    tensorboard writer, builds the model/criterion/optimizer/scheduler, then
    either runs a single test pass (``args.test``) or the full train+validate
    loop with per-epoch checkpointing.

    NOTE(review): depends on helpers defined elsewhere in this file
    (``systemInfo``, ``load_model``, ``get_loader``, ``train``, ``validate``,
    ``test``, ``save_checkpoint``).
    """
    systemInfo()

    # Sanity-check the expected directory layout under args.root.
    # NOTE(review): `assert` is stripped under `python -O`; raising would be
    # more robust for input validation.
    dirs = os.listdir(args.root)
    if args.test:
        assert "tst" in dirs, "A 'tst' directory is not in {}".format(
            args.root)
    else:
        assert "train" in dirs, "A 'train' directory is not in {}".format(
            args.root)
        assert "val" in dirs, "A 'val' directory is not in {}".format(
            args.root)
        assert "labels" in dirs, "A 'labels' directory is not in {}".format(
            args.root)

    del dirs

    if args.is_cropped:
        assert args.crop_size[0] == args.crop_size[
            1], "Crop size is assumed to be square, but you supplied {}".format(
                args.crop_size)
        args.sample_size = args.crop_size[0]

    args.sample_duration = args.frames

    # Default output dir is timestamped.
    # NOTE(review): the strftime pattern contains ':' which is invalid in
    # Windows paths -- fine on POSIX only.
    if args.output == "":
        now = datetime.datetime.now()
        args.output = os.path.join("./results",
                                   now.strftime("%Y-%m-%d_%H:%M:%S"))
        del now

    # NOTE(review): os.mkdir (not makedirs) assumes './results' already
    # exists when args.output was defaulted.
    if not os.path.exists(args.output):
        os.mkdir(args.output)
        os.mkdir(os.path.join(args.output, "weights"))

    print("Output path: {}".format(args.output))

    # Persist the full argument namespace for reproducibility.
    with open(os.path.join(args.output, "Settings.txt"), "w") as outfile:
        outfile.write(str(vars(args)))

    print("Setting up Tensorboard")
    writer = SummaryWriter()
    writer.add_text('config', str(vars(args)))
    print("Tensorboard set up")

    print("Setting Pytorch cuda settings")
    torch.cuda.set_device(0)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = True
    print("Set Pytorch cuda settings\n")

    print("Creating model '{}'".format(args.arch))
    model = load_model(args)
    print("Model created\n")

    # GPU placement: single specified GPU, the only GPU, or DataParallel
    # across all visible GPUs.
    if args.gpu is not None:
        print("Using GPU {}\n".format(args.gpu))
        model = model.cuda(args.gpu)
    elif torch.cuda.device_count() == 1:
        print("Using a single GPU\n")
        model = model.cuda()
    else:
        print("Using {} GPUs\n".format(torch.cuda.device_count()))
        model = nn.DataParallel(model).cuda()

    print("Setting up loss and optimizer")
    # BCE for binary classification, NLL (expects log-probabilities) otherwise.
    if args.num_classes == 1:
        criterion = nn.BCELoss().cuda(args.gpu)
    else:
        criterion = nn.NLLLoss().cuda(args.gpu)

    optimizer = optim.SGD(model.parameters(),
                          args.lr,
                          momentum=args.momentum,
                          nesterov=args.nesterov,
                          weight_decay=args.weight_decay)

    scheduler = optim.lr_scheduler.StepLR(optimizer, args.step_size,
                                          args.gamma)
    print("Optimizer and loss function setup\n")

    # Optionally resume training state (epoch, best accuracy, weights,
    # optimizer state) from a checkpoint file.
    best_accV = -1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            print("Loading checkpoint from epoch {} with val accuracy of {}".
                  format(checkpoint['epoch'], checkpoint['best_accV']))

            args.start_epoch = checkpoint['epoch']
            best_accV = checkpoint['best_accV']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})\n".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'\n".format(args.resume))

    if args.test:
        print("Initializing testing dataloaders")
        test_loader, test_batches, sampler = get_loader(args)
        tst_samples_per_epoch = test_batches * args.test_batchsize

        print("Test Batch size: {}\nTest batches: {}\nTest videos: {}".format(
            args.test_batchsize, test_batches, len(test_loader.files)))
        print('Dataloaders initialized\n')

        # evaluate on validation set
        timeT = test(test_loader, model, args)

    else:
        print("Initializing training dataloaders")
        train_loader, train_batches, val_loader, val_batches, sampler = get_loader(
            args)
        trn_samples_per_epoch = train_batches * args.batchsize
        val_samples_per_epoch = val_batches * args.val_batchsize
        print(args.root)
        print(
            "Trn Batch size: {}\nTrn batches: {}\nTrn videos: {}\nVal Batch size: {}\nVal batches: {}\nVal videos: {}\nTrn samples per epoch: {}\nVals samples per epoch: {}"
            .format(args.batchsize, train_batches,
                    len(train_loader.files), args.val_batchsize, val_batches,
                    len(val_loader.files), trn_samples_per_epoch,
                    val_samples_per_epoch))
        print('Dataloaders initialized\n')

        for epoch in range(args.start_epoch, args.epochs):
            _start = time.time()

            # NOTE(review): scheduler.step() is called at the start of the
            # epoch, before optimizer.step(); since PyTorch 1.1 the
            # recommended order is after training -- confirm intended.
            scheduler.step()
            writer.add_scalar('Learning Rate', optimizer.param_groups[0]["lr"],
                              epoch)

            # train for one epoch
            lossT, accT, timeT = train(train_loader, model, criterion,
                                       optimizer, epoch, writer, args)
            writer.add_scalar('Loss/Training-Avg', lossT, epoch)
            writer.add_scalar('Accuracy/Training', accT, epoch)
            writer.add_scalar('Time/Training-Avg', timeT, epoch)

            print("Epoch {} training completed: {}".format(
                epoch,
                datetime.datetime.now().isoformat()))
            print("Train time {}".format(timeT))
            time.sleep(1)

            # evaluate on validation set
            lossV, accV, timeV = validate(val_loader, model, criterion, args,
                                          epoch)
            writer.add_scalar('Loss/Validation-Avg', lossV, epoch)
            writer.add_scalar('Accuracy/Validation', accV, epoch)
            writer.add_scalar('Time/Validation-Avg', timeV, epoch)

            print("Epoch {} validation completed: {}".format(
                epoch,
                datetime.datetime.now().isoformat()))
            print("Val time {}".format(timeV))

            # remember best acc@1 and save checkpoint
            is_best = accV > best_accV
            best_accV = max(accV, best_accV)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_accV': best_accV,
                    'accV': accV,
                    'accT': accT,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                filename='checkpoint_{}.pth.tar'.format(epoch),
                dir=os.path.join(args.output, "weights"))

            _end = time.time()

            print(
                "Epoch {}\n\tTime: {} seconds\n\tTrain Loss: {}\n\tTrain Accuracy: {}\n\tValidation Loss: {}\n\tValidation Accuracy: {}\n"
                .format(epoch, _end - _start, lossT, accT, lossV, accV))

            print("Train time {}\nVal time {}".format(timeT, timeV))
コード例 #15
0
class Logger:
    """Thin wrapper around a tensorboardX ``SummaryWriter``.

    Adds higher-level conveniences on top of the raw writer: batched
    scalar/value logging, FPS tracking, image-grid logging, periodic model
    checkpointing and config dumps.
    """
    def __init__(self,
                 steps_per_epoch,
                 model,
                 save_model_every_nth=100,
                 shared_model_path='.'):
        """
        :param steps_per_epoch: optimizer steps per epoch (used by log_fps)
        :param model: model object; must expose ``save_model(path)``
        :param save_model_every_nth: checkpoint interval, in epochs
        :param shared_model_path: extra directory that also receives checkpoints
        """
        self.shared_model_path = shared_model_path
        # NOTE(review): str(datetime.now()) contains spaces/colons; awkward
        # (and invalid on Windows) in paths -- kept for log continuity.
        self.loggin_path = "./logs/" + str(datetime.datetime.now())
        self.writer = SummaryWriter(self.loggin_path)
        self.steps_per_epoch = steps_per_epoch
        self.t = datetime.datetime.now()  # timestamp of the last log_fps call
        self.model = model
        self.save_model_every_nth = save_model_every_nth

    def log_values(self, epoch, values: dict = None):
        """Log a dict of values; the value under key ``'loss'`` is also printed.

        :param epoch: current epoch (used as the tensorboard global step)
        :param values: mapping of tag -> scalar or tag -> dict of scalars.
            A dict value is logged with ``add_scalars`` (one combined graph,
            which creates extra entries in tensorboard's runs section);
            anything else goes to its own graph via ``add_scalar``.
            Example (PGGAN):
            {'loss': {'lossG': g_loss_summed,   # one graph, also printed
                      'lossD': d_loss_summed},
             'info/WassersteinDistance': wasserstein_d_summed,  # own graph
             'info/eps': eps_summed}
        """
        if values is None:
            values = {}
        for k, v in values.items():

            if isinstance(v, dict):
                self.writer.add_scalars(k, v, epoch)
            else:
                self.writer.add_scalar(k, v, epoch)
            # Fixed: was `k is 'loss'` -- identity comparison with a string
            # literal is fragile (SyntaxWarning on modern CPython).
            if k == 'loss':
                print(f"epoch: {epoch}" + json.dumps(v), end='\n')

    def log_fps(self, epoch):
        """Log steps-per-second measured since the previous log_fps call.

        :param epoch: current epoch
        """
        new_time = datetime.datetime.now()
        self.writer.add_scalar(
            "info/fps",
            self.steps_per_epoch * 1.0 / (new_time - self.t).total_seconds(),
            epoch)
        self.t = new_time

    def log_images(self, epoch, images, tag_name, columns):
        """Log a grid of images to tensorboard.

        :param epoch: current epoch
        :param images: list of images, scaled to 0-1
        :param tag_name: tag of the image grid in tensorboard
        :param columns: number of columns in the grid
        """
        grid = vutils.make_grid(images,
                                normalize=True,
                                scale_each=False,
                                nrow=columns)
        self.writer.add_image(tag_name, grid, epoch)

    def save_model(self, epoch):
        """Checkpoint the model every ``save_model_every_nth`` epochs
        (including epoch 0) to both the log dir and the shared dir."""
        if epoch % self.save_model_every_nth == 0:  # and epoch > 0:
            self.model.save_model(self.loggin_path)
            self.model.save_model(self.shared_model_path)

    def log_config(self, config):
        """Dump hyperparameters, model repr and the config source to tensorboard.

        :param config: config object/module; must be introspectable by
            ``inspect.getsource`` / ``inspect.getfile``
        """
        base_tag = 'config'
        # log batch size and device count
        text = f"batch size: {config.batch_size}\n\nnum gpus: {torch.cuda.device_count()}"
        self.writer.add_text(base_tag + "/hyperparameters", text)
        # log model
        text = str(self.model)
        self.writer.add_text(base_tag + "/model", text)
        # log complete config
        text = inspect.getsource(config)
        self.writer.add_text(base_tag + '/config',
                             '\t' + text.replace('\n', '\n\t'))

        # log rest of the config file for all params
        file = inspect.getfile(config)
        with open(file) as f:
            self.writer.add_text('rest', '\t' + f.read().replace('\n', '\n\t'))
コード例 #16
0
class SimpleModelLog:
    """Simple multi-sink training logger.

    Generates four kinds of log under ``model_dir``:
    1. readable ``log.txt`` (metric dicts flattened),
    2. TensorBoard scalars and texts (under ``model_dir/summary``),
    3. multi-line JSON file ``log.json.lst`` (one metrics dict per line),
    4. ``tensorboard_scalars.json``, all scalars exported in tensorboard's
       JSON format on close().
    """
    def __init__(self, model_dir):
        self.model_dir = Path(model_dir)
        self.log_file = None
        self.log_mjson_file = None
        # (sic) attribute name kept misspelled for API compatibility.
        self.summary_writter = None
        self.metrics = []           # all metric dicts, including reloaded ones
        self._text_current_gstep = -1  # step the text buffer belongs to
        self._tb_texts = []            # texts buffered for the current step

    def open(self):
        """Open all log sinks; returns self so calls can be chained."""
        model_dir = self.model_dir
        assert model_dir.exists()
        summary_dir = model_dir / 'summary'
        summary_dir.mkdir(parents=True, exist_ok=True)

        log_mjson_file_path = model_dir / 'log.json.lst'
        if log_mjson_file_path.exists():
            # Reload previously recorded metrics so self.metrics stays complete
            # across restarts.
            with open(log_mjson_file_path, 'r') as f:
                for line in f:
                    self.metrics.append(json.loads(line))
        log_file_path = model_dir / 'log.txt'
        self.log_mjson_file = open(log_mjson_file_path, 'a')
        self.log_file = open(log_file_path, 'a+')
        self.summary_writter = SummaryWriter(str(summary_dir))
        return self

    def close(self):
        """Flush and close all sinks, exporting scalars to JSON first."""
        assert self.summary_writter is not None
        self.log_mjson_file.close()
        self.log_file.close()
        tb_json_path = str(self.model_dir / "tensorboard_scalars.json")
        self.summary_writter.export_scalars_to_json(tb_json_path)
        self.summary_writter.close()
        self.log_mjson_file = None
        self.log_file = None
        self.summary_writter = None

    def log_text(self, text, step, tag="regular log"):
        """Add text to stdout, log.txt and (batched per step) tensorboard texts.
        """
        print(text)
        print(text, file=self.log_file)
        if step > self._text_current_gstep and self._text_current_gstep != -1:
            # New global step: flush everything buffered for the previous step.
            total_text = '\n'.join(self._tb_texts)
            self.summary_writter.add_text(tag, total_text, global_step=step)
            # NOTE(review): the `text` that triggers the flush is NOT buffered,
            # so it never reaches tensorboard -- looks unintended; kept as-is
            # to preserve behaviour.
            self._tb_texts = []
            self._text_current_gstep = step
        else:
            self._tb_texts.append(text)
        if self._text_current_gstep == -1:
            self._text_current_gstep = step

    def log_metrics(self, metrics: dict, step):
        """Log a (possibly nested) metrics dict to scalars, log.txt and json.lst.

        String values are skipped for tensorboard; list/tuple values are
        expanded to indexed scalars (``tag/0``, ``tag/1``, ...).
        """
        flatted_summarys = flat_nested_json_dict(metrics, "/")
        for k, v in flatted_summarys.items():
            if isinstance(v, (list, tuple)):
                if any(isinstance(e, str) for e in v):
                    continue
                v_dict = {str(i): e for i, e in enumerate(v)}
                for k1, v1 in v_dict.items():
                    self.summary_writter.add_scalar(k + "/" + k1, v1, step)
            else:
                if isinstance(v, str):
                    continue
                self.summary_writter.add_scalar(k, v, step)
        log_str = metric_to_str(metrics)
        print(log_str)
        print(log_str, file=self.log_file)
        print(json.dumps(metrics), file=self.log_mjson_file)
コード例 #17
0
ファイル: train.py プロジェクト: hlwang1124/AAFramework
                    interpolation=cv2.INTER_NEAREST),
                                      axis=0)

                conf_mat += confusion_matrix(gt, pred,
                                             valid_dataset.dataset.num_labels)
                losses = model.get_current_losses()
                valid_loss_iter.append(model.loss_segmentation)
                print('valid epoch {0:}, iters: {1:}/{2:} '.format(
                    epoch, epoch_iter,
                    len(valid_dataset) * valid_opt.batch_size),
                      end='\r')

        avg_valid_loss = torch.mean(torch.stack(valid_loss_iter))
        globalacc, pre, recall, F_score, iou = getScores(conf_mat)

        # Record performance on the validation set
        writer.add_scalar('valid/loss', avg_valid_loss, epoch)
        writer.add_scalar('valid/global_acc', globalacc, epoch)
        writer.add_scalar('valid/pre', pre, epoch)
        writer.add_scalar('valid/recall', recall, epoch)
        writer.add_scalar('valid/F_score', F_score, epoch)
        writer.add_scalar('valid/iou', iou, epoch)

        # Save the best model according to the F-score, and record corresponding epoch number in tensorboard
        if F_score > F_score_max:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save_networks('best')
            F_score_max = F_score
            writer.add_text('best model', str(epoch))
コード例 #18
0
ファイル: main.py プロジェクト: xuezu29/DALI
def main(args):
    """Train and validate VSRNet, optionally across multiple NCCL ranks.

    Rank 0 owns all logging (a tensorboardX ``SummaryWriter``); every other
    rank runs with ``writer = None``.  Training runs in a while-loop over
    iterations until ``args.max_iter`` total samples (across ranks) are
    consumed; a full validation pass follows every training epoch.

    Args:
        args: parsed command-line namespace (rank/world_size, loader type,
              learning-rate schedule bounds, fp16 flag, paths, etc.).
    """
    if args.rank == 0:
        log.basicConfig(level=log.INFO)
        writer = SummaryWriter()
        writer.add_text('config', str(args))
    else:
        log.basicConfig(level=log.WARNING)
        writer = None

    # One GPU per rank; seeds are offset by rank so workers draw
    # different random streams.
    torch.cuda.set_device(args.rank % args.world_size)
    torch.manual_seed(args.seed + args.rank)
    torch.cuda.manual_seed(args.seed + args.rank)
    torch.backends.cudnn.benchmark = True

    if args.world_size > 1:
        log.info('Initializing process group')
        dist.init_process_group(
            backend='nccl',
            init_method='tcp://' + args.ip + ':3567',
            world_size=args.world_size,
            rank=args.rank)
        log.info('Process group initialized')

    log.info('Initializing ' + args.loader + ' training dataloader...')
    train_loader, train_batches, sampler = get_loader(args, 'train')
    samples_per_epoch = train_batches * args.batchsize
    log.info('Dataloader initialized')

    model = VSRNet(args.frames, args.flownet_path, args.fp16)
    if args.fp16:
        network_to_half(model)
    model.cuda()
    model.train()
    # The optical-flow sub-network is frozen; only the rest of VSRNet trains.
    for param in model.FlowNetSD_network.parameters():
        param.requires_grad = False

    model_params = [p for p in model.parameters() if p.requires_grad]
    # lr=1 is intentional: LambdaLR multiplies the base lr by the cyclic
    # schedule value, so the lambda alone determines the effective rate.
    optimizer = optim.Adam(model_params, lr=1, weight_decay=args.weight_decay)
    stepsize = 2 * train_batches
    clr_lambda = cyclic_learning_rate(args.min_lr, args.max_lr, stepsize)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[clr_lambda])
    if args.fp16:
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

    if args.world_size > 1:
        model = DistributedDataParallel(model)

    # TRAINING
    total_iter = 0
    while total_iter * args.world_size < args.max_iter:

        epoch = floor(total_iter / train_batches)

        # only if we are using DistributedSampler
        if args.world_size > 1 and args.loader == 'pytorch':
            sampler.set_epoch(epoch)

        model.train()
        total_epoch_loss = 0.0

        sample_timer = 0.0
        data_timer = 0.0
        compute_timer = 0.0

        iter_start = time.perf_counter()

        training_data_times = []
        training_start = datetime.datetime.now()

        # TRAINING EPOCH LOOP
        for i, inputs in enumerate(train_loader):
            training_stop = datetime.datetime.now()
            dataloading_time = training_stop - training_start
            training_data_times.append(dataloading_time.total_seconds() * 1000.0)

            if args.loader == 'DALI':
                inputs = inputs[0]["data"]
                # Needed? It is already gpu
                inputs = inputs.cuda(non_blocking=True)
            else:
                inputs = inputs.cuda(non_blocking=True)
                if args.fp16:
                    inputs = inputs.half()

            if args.timing:
                torch.cuda.synchronize()
                data_end = time.perf_counter()

            optimizer.zero_grad()

            # Emit intermediate images to tensorboard every image_freq iters.
            im_out = total_iter % args.image_freq == 0
            # writer.add_graph(model, inputs)
            loss = model(Variable(inputs), i, writer, im_out)

            total_epoch_loss += loss.item()

            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()

            optimizer.step()
            scheduler.step()

            if args.rank == 0:
                if args.timing:
                    torch.cuda.synchronize()
                    iter_end = time.perf_counter()
                    sample_timer += (iter_end - iter_start)
                    data_duration = data_end - iter_start
                    data_timer += data_duration
                    compute_timer += (iter_end - data_end)
                    torch.cuda.synchronize()
                    iter_start = time.perf_counter()
                writer.add_scalar('learning_rate', scheduler.get_lr()[0], total_iter)
                writer.add_scalar('train_loss', loss.item(), total_iter)

            log.info('Rank %d, Epoch %d, Iteration %d of %d, loss %.5f' %
                    (args.rank, epoch, i+1, train_batches, loss.item()))

            if total_iter % 100 == 0:
                print("Avg dataloading time: " + str(reduce(lambda x, y: x + y, training_data_times) / len(training_data_times)) + "ms")

            total_iter += 1
            # NOTE(review): the while-loop exits on
            # total_iter * world_size >= max_iter, but this break compares
            # plain total_iter against max_iter — confirm which is intended.
            if total_iter > args.max_iter:
                break

            training_start = datetime.datetime.now()

        if args.rank == 0:
            if args.timing:
                sample_timer_avg = sample_timer / samples_per_epoch
                writer.add_scalar('sample_time', sample_timer_avg, total_iter)
                data_timer_avg = data_timer / samples_per_epoch
                writer.add_scalar('sample_data_time', data_timer_avg, total_iter)
                compute_timer_avg = compute_timer / samples_per_epoch
                writer.add_scalar('sample_compute_time', compute_timer_avg, total_iter)
            epoch_loss_avg = total_epoch_loss / train_batches
            log.info('Rank %d, epoch %d: %.5f' % (args.rank, epoch, epoch_loss_avg))



        ### VALIDATION
        # NOTE(review): the validation dataloader is rebuilt on every epoch;
        # hoisting it above the while-loop would avoid the repeated setup.
        log.info('Initializing ' + args.loader + ' validation dataloader...')
        val_loader, val_batches, sampler = get_loader(args, 'val')
        model.eval()
        total_loss = 0
        total_psnr = 0
        val_iters = 0
        for i, inputs in enumerate(val_loader):
            if args.loader == 'DALI':
                inputs = inputs[0]["data"]
                # Needed? It is already gpu
                inputs = inputs.cuda(non_blocking=True)
            else:
                inputs = inputs.cuda(non_blocking=True)
                if args.fp16:
                    inputs = inputs.half()

            log.info('Validation it %d of %d' % (i + 1, val_batches))
            loss, psnr = model(Variable(inputs), i, None)
            total_loss += loss.item()
            total_psnr += psnr.item()
            val_iters += 1

        # BUGFIX: the averages previously divided by the last enumerate
        # index ``i`` (i.e. N - 1), skewing both metrics and raising
        # ZeroDivisionError when validation had exactly one batch.
        loss = total_loss / max(val_iters, 1)
        psnr = total_psnr / max(val_iters, 1)

        if args.rank == 0:
            writer.add_scalar('val_loss', loss, total_iter)
            writer.add_scalar('val_psnr', psnr, total_iter)
        log.info('Rank %d validation loss %.5f' % (args.rank, loss))
        log.info('Rank %d validation psnr %.5f' % (args.rank, psnr))
コード例 #19
0
def worker(gpu, ngpus_per_node, args):
    """Per-GPU worker for (distributed) DQN training on Atari environments.

    Sets up seeds/devices/process group for this rank, builds the evaluation
    environment and agent, then either runs a one-off evaluation
    (``args.evaluate``) or enters the main training loop: act in the
    vectorized training env on one CUDA stream while learning from replay
    on another, with rank 0 handling tqdm progress, CSV output and
    tensorboard plots.

    Args:
        gpu: local GPU index for this worker.
        ngpus_per_node: number of GPUs on this node (used to derive rank).
        args: parsed command-line namespace.
    """
    args.gpu = gpu

    if args.distributed:
        args.seed += args.gpu
        torch.cuda.set_device(args.gpu)

        args.rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + args.gpu

        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:8632',
            world_size=args.world_size,
            rank=args.rank)
    else:
        args.rank = 0

    args.use_cuda_env = args.use_cuda_env and torch.cuda.is_available()
    args.no_cuda_train = not torch.cuda.is_available()
    args.verbose = args.verbose and (args.rank == 0)

    env_device = torch.device(
        'cuda', args.gpu) if args.use_cuda_env else torch.device('cpu')
    train_device = torch.device('cuda', args.gpu) if (
        args.no_cuda_train == False) else torch.device('cpu')

    # Setup
    np.random.seed(args.seed)
    torch.manual_seed(np.random.randint(1, 10000))
    if args.use_cuda_env or (args.no_cuda_train == False):
        torch.cuda.manual_seed(random.randint(1, 10000))

    if train_device.type == 'cuda':
        print('Train:\n' + cuda_device_str(train_device.index), flush=True)

    # Evaluation env: episode-life / reward-clipping disabled so scores
    # reflect raw game performance.
    if args.use_openai:
        test_env = create_vectorize_atari_env(args.env_name,
                                              args.seed,
                                              args.evaluation_episodes,
                                              episode_life=False,
                                              clip_rewards=False)
        test_env.reset()
    else:
        test_env = AtariEnv(args.env_name,
                            args.evaluation_episodes,
                            color_mode='gray',
                            device='cpu',
                            rescale=True,
                            clip_rewards=False,
                            episodic_life=False,
                            repeat_prob=0.0,
                            frameskip=4)

    # Agent
    dqn = Agent(args, test_env.action_space)

    # Construct validation memory
    if args.rank == 0:
        print('Initializing evaluation memory with {} entries...'.format(
            args.evaluation_size),
              end='',
              flush=True)
        start_time = time.time()

    val_mem = initialize_validation(args, train_device)

    if args.rank == 0:
        print('complete ({})'.format(format_time(time.time() - start_time)),
              flush=True)

    if args.evaluate:
        dqn.eval()
        rewards, lengths, avg_Q = test(args, 0, dqn, val_mem, test_env,
                                       train_device)  # Test
    else:
        if args.rank == 0:
            print('Entering main training loop', flush=True)

            if args.output_filename:
                csv_file = open(args.output_filename, 'w', newline='')
                csv_file.write(json.dumps(vars(args)))
                csv_file.write('\n')
                csv_writer = csv.writer(csv_file, delimiter=',')
                csv_writer.writerow([
                    'frames', 'total_time', 'rmean', 'rmedian', 'rstd', 'rmin',
                    'rmax', 'lmean', 'lmedian', 'lstd', 'lmin', 'lmax'
                ])
            else:
                csv_writer, csv_file = None, None

            if args.plot:
                from tensorboardX import SummaryWriter
                current_time = datetime.now().strftime('%b%d_%H-%M-%S')
                log_dir = os.path.join(
                    args.log_dir, current_time + '_' + socket.gethostname())
                writer = SummaryWriter(log_dir=log_dir)
                for k, v in vars(args).items():
                    writer.add_text(k, str(v))

            # Environment
            print('Initializing environments...', end='', flush=True)
            start_time = time.time()

        if args.use_openai:
            train_env = create_vectorize_atari_env(
                args.env_name,
                args.seed,
                args.num_ales,
                episode_life=True,
                clip_rewards=args.reward_clip,
                max_frames=args.max_episode_length)
            observation = torch.from_numpy(train_env.reset()).squeeze(1)
        else:
            train_env = AtariEnv(args.env_name,
                                 args.num_ales,
                                 color_mode='gray',
                                 device=env_device,
                                 rescale=True,
                                 clip_rewards=args.reward_clip,
                                 episodic_life=True,
                                 repeat_prob=0.0)
            train_env.train()
            observation = train_env.reset(
                initial_steps=args.ale_start_steps,
                verbose=args.verbose).clone().squeeze(-1)

        if args.rank == 0:
            print('complete ({})'.format(format_time(time.time() -
                                                     start_time)),
                  flush=True)

        # These variables are used to compute average rewards for all processes.
        episode_rewards = torch.zeros(args.num_ales,
                                      device=train_device,
                                      dtype=torch.float32)
        episode_lengths = torch.zeros(args.num_ales,
                                      device=train_device,
                                      dtype=torch.float32)
        final_rewards = torch.zeros(args.num_ales,
                                    device=train_device,
                                    dtype=torch.float32)
        final_lengths = torch.zeros(args.num_ales,
                                    device=train_device,
                                    dtype=torch.float32)
        has_completed = torch.zeros(args.num_ales,
                                    device=train_device,
                                    dtype=torch.uint8)

        mem = ReplayMemory(args, args.memory_capacity, train_device)
        mem.reset(observation)
        priority_weight_increase = (1 - args.priority_weight) / (
            args.t_max - args.learn_start)

        state = torch.zeros((args.num_ales, args.history_length, 84, 84),
                            device=mem.device,
                            dtype=torch.float32)
        state[:, -1] = observation.to(device=mem.device,
                                      dtype=torch.float32).div(255.0)

        num_frames_per_iter = args.num_ales
        total_steps = math.ceil(args.t_max /
                                (args.world_size * num_frames_per_iter))
        epsilons = np.linspace(
            args.epsilon_start, args.epsilon_final,
            math.ceil(args.epsilon_frames / num_frames_per_iter))
        epsilon_offset = math.ceil(args.learn_start / num_frames_per_iter)

        prefetcher = data_prefetcher(args.batch_size, train_device, mem)

        avg_loss = 'N/A'
        eval_offset = 0
        target_update_offset = 0

        total_time = 0
        env_time = 0
        mem_time = 0
        net_time = 0

        fps_steps = 0
        fps_start_time = time.time()

        # main loop
        iterator = range(total_steps)
        if args.rank == 0:
            iterator = tqdm(iterator)

        # Environment stepping and network training overlap on separate
        # CUDA streams; both are joined before the replay-memory append.
        env_stream = torch.cuda.Stream()
        train_stream = torch.cuda.Stream()

        for update in iterator:

            T = args.world_size * update * num_frames_per_iter
            epsilon = epsilons[min(
                update - epsilon_offset,
                len(epsilons) - 1)] if T >= args.learn_start else epsilons[0]
            start_time = time.time()

            if update % args.replay_frequency == 0:
                dqn.reset_noise()  # Draw a new set of noisy weights

            dqn.eval()
            nvtx.range_push('train:select action')
            if args.noisy_linear:
                action = dqn.act(
                    state)  # Choose an action greedily (with noisy weights)
            else:
                action = dqn.act_e_greedy(state, epsilon=epsilon)
            nvtx.range_pop()
            dqn.train()

            fps_steps += 1

            if args.use_openai:
                action = action.cpu().numpy()

            torch.cuda.synchronize()

            with torch.cuda.stream(env_stream):
                nvtx.range_push('train:env step')
                if args.use_openai:
                    observation, reward, done, info = train_env.step(
                        action)  # Step
                    # convert back to pytorch tensors
                    observation = torch.from_numpy(observation).squeeze(1)
                    reward = torch.from_numpy(reward.astype(np.float32))
                    done = torch.from_numpy(done.astype(np.uint8))
                    action = torch.from_numpy(action)
                else:
                    observation, reward, done, info = train_env.step(
                        action, asyn=True)  # Step
                    observation = observation.clone().squeeze(-1)
                nvtx.range_pop()

                observation = observation.to(device=train_device)
                reward = reward.to(device=train_device)
                done = done.to(device=train_device)
                action = action.to(device=train_device)

                delta = time.time() - start_time
                env_time += delta
                total_time += delta

                observation = observation.float().div_(255.0)

                # Shift the frame stack left and zero states whose episode
                # just ended before appending the new frame.
                state[:, :-1].copy_(state[:, 1:].clone())
                state *= (1 - done).view(-1, 1, 1, 1).float()
                state[:, -1].copy_(observation)

                # update episodic reward counters
                not_done = (1 - done).float()
                has_completed |= (done == 1)

                episode_rewards += reward.float()
                final_rewards[done] = episode_rewards[done]
                episode_rewards *= not_done

                episode_lengths += not_done
                final_lengths[done] = episode_lengths[done]
                episode_lengths *= not_done

            # Train and test
            if T >= args.learn_start:
                mem.priority_weight = min(
                    mem.priority_weight + priority_weight_increase,
                    1)  # Anneal importance sampling weight β to 1
                prefetcher.preload()

                avg_loss = 0.0
                num_minibatches = min(
                    int(args.num_ales / args.replay_frequency), 8)
                for _ in range(num_minibatches):
                    # Sample transitions
                    start_time = time.time()
                    nvtx.range_push('train:sample states')
                    idxs, states, actions, returns, next_states, nonterminals, weights = prefetcher.next(
                    )
                    nvtx.range_pop()
                    delta = time.time() - start_time
                    mem_time += delta
                    total_time += delta

                    start_time = time.time()
                    nvtx.range_push('train:network update')
                    loss = dqn.learn(states, actions, returns, next_states,
                                     nonterminals, weights)
                    nvtx.range_pop()
                    delta = time.time() - start_time
                    net_time += delta
                    total_time += delta

                    start_time = time.time()
                    nvtx.range_push('train:update priorities')
                    mem.update_priorities(
                        idxs, loss)  # Update priorities of sampled transitions
                    nvtx.range_pop()
                    delta = time.time() - start_time
                    mem_time += delta
                    total_time += delta

                    avg_loss += loss.mean().item()
                avg_loss /= num_minibatches

                # Update target network
                if T >= target_update_offset:
                    dqn.update_target_net()
                    target_update_offset += args.target_update

            torch.cuda.current_stream().wait_stream(env_stream)
            torch.cuda.current_stream().wait_stream(train_stream)

            start_time = time.time()
            nvtx.range_push('train:append memory')
            mem.append(observation, action, reward,
                       done)  # Append transition to memory
            nvtx.range_pop()
            delta = time.time() - start_time
            mem_time += delta
            total_time += delta

            fps_end_time = time.time()
            fps = (args.world_size * fps_steps *
                   args.num_ales) / (fps_end_time - fps_start_time)
            fps_start_time = fps_end_time
            fps_steps = 0

            if args.rank == 0:
                if args.plot and ((update % args.replay_frequency) == 0):
                    writer.add_scalar('train/epsilon', epsilon, T)
                    writer.add_scalar('train/rewards', final_rewards.mean(), T)
                    writer.add_scalar('train/lengths', final_lengths.mean(), T)

                if T >= eval_offset:
                    eval_start_time = time.time()
                    dqn.eval()  # Set DQN (online network) to evaluation mode
                    rewards, lengths, avg_Q = test(args, T, dqn, val_mem,
                                                   test_env, train_device)
                    dqn.train(
                    )  # Set DQN (online network) back to training mode
                    # BUGFIX: was ``time.time() - start_time``, which used the
                    # stale timer from the memory-append block above and
                    # reported a wrong evaluation duration.
                    eval_total_time = time.time() - eval_start_time
                    eval_offset += args.evaluation_interval

                    rmean, rmedian, rstd, rmin, rmax = vec_stats(rewards)
                    lmean, lmedian, lstd, lmin, lmax = vec_stats(lengths)

                    print('reward: {:4.2f}, {:4.0f}, {:4.0f}, {:4.4f} | '
                          'length: {:4.2f}, {:4.0f}, {:4.0f}, {:4.4f} | '
                          'Avg. Q: {:4.4f} | {} | Overall FPS: {:4.2f}'.format(
                              rmean, rmin, rmax, rstd, lmean, lmin, lmax, lstd,
                              avg_Q, format_time(eval_total_time), fps),
                          flush=True)

                    if args.output_filename and csv_writer and csv_file:
                        csv_writer.writerow([
                            T, total_time, rmean, rmedian, rstd, rmin, rmax,
                            lmean, lmedian, lstd, lmin, lmax
                        ])
                        csv_file.flush()

                    if args.plot:
                        writer.add_scalar('eval/rewards', rmean, T)
                        writer.add_scalar('eval/lengths', lmean, T)
                        writer.add_scalar('eval/avg_Q', avg_Q, T)

                loss_str = '{:4.4f}'.format(avg_loss) if isinstance(
                    avg_loss, float) else avg_loss
                progress_data = 'T = {:,} epsilon = {:4.2f} avg reward = {:4.2f} loss: {} ({:4.2f}% net, {:4.2f}% mem, {:4.2f}% env)' \
                                .format(T, epsilon, final_rewards.mean().item(), loss_str, \
                                        *percent_time(total_time, net_time, mem_time, env_time))
                iterator.set_postfix_str(progress_data)

    if args.plot and (args.rank == 0):
        writer.close()

    if args.use_openai:
        train_env.close()
        test_env.close()
コード例 #20
0
ファイル: logger.py プロジェクト: johndpope/jukebox-1
class Logger:
    """Rank-aware training logger.

    Only rank 0 owns a tensorboardX ``SummaryWriter``; on every other rank
    each logging method is a silent no-op.  A monotonically increasing
    ``iters`` counter is used as the global step for every write.
    """

    def __init__(self, logdir, rank):
        self.iters = 0
        self.rank = rank
        self.works = []  # pending (tag, layer, val, work) reduce handles
        self.logdir = logdir
        if rank == 0:
            from tensorboardX import SummaryWriter
            self.sw = SummaryWriter(f"{logdir}/logs")

    def step(self):
        """Advance the global step counter by one."""
        self.iters += 1

    def flush(self):
        if self.rank != 0:
            return
        self.sw.flush()

    def add_text(self, tag, text):
        if self.rank != 0:
            return
        self.sw.add_text(tag, text, self.iters)

    def add_audios(self, tag, auds, sample_rate=22050, max_len=None, max_log=8):
        """Log up to ``max_log`` audio clips, each optionally truncated to
        ``max_len`` seconds."""
        if self.rank != 0:
            return
        limit = min(len(auds), max_log)
        for idx in range(limit):
            clip = auds[idx][:max_len * sample_rate] if max_len else auds[idx]
            self.sw.add_audio(f"{idx}/{tag}", clip, self.iters, sample_rate)

    def add_audio(self, tag, aud, sample_rate=22050):
        if self.rank != 0:
            return
        self.sw.add_audio(tag, aud, self.iters, sample_rate)

    def add_images(self, tag, img, dataformats="NHWC"):
        if self.rank != 0:
            return
        self.sw.add_images(tag, img, self.iters, dataformats=dataformats)

    def add_image(self, tag, img):
        if self.rank != 0:
            return
        self.sw.add_image(tag, img, self.iters)

    def add_scalar(self, tag, val):
        if self.rank != 0:
            return
        self.sw.add_scalar(tag, val, self.iters)

    def get_range(self, loader):
        """Wrap ``loader`` in a tqdm progress bar on rank 0 and return it
        enumerated; other ranks enumerate the loader directly."""
        self.trange = def_tqdm(loader) if self.rank == 0 else loader
        return enumerate(self.trange)

    def close_range(self):
        if self.rank != 0:
            return
        self.trange.close()

    def set_postfix(self, *args, **kwargs):
        if self.rank != 0:
            return
        self.trange.set_postfix(*args, **kwargs)

    # For logging summaries of varies graph ops
    def add_reduce_scalar(self, tag, layer, val):
        """Every 100 steps, asynchronously reduce a normalized scalar to
        rank 0; the result is consumed later by ``finish_reduce``."""
        if self.iters % 100 != 0:
            return
        with t.no_grad():
            val = val.float().norm() / float(val.numel())
        work = dist.reduce(val, 0, async_op=True)
        self.works.append((tag, layer, val, work))

    def finish_reduce(self):
        """Wait for all pending reduces and log their averages on rank 0."""
        for tag, layer, val, work in self.works:
            work.wait()
            if self.rank == 0:
                val = val.item() / dist.get_world_size()
                # NOTE(review): ``self.lw`` is never assigned anywhere in
                # this class — this line will raise AttributeError on
                # rank 0; confirm whether ``self.sw`` was intended.
                self.lw[layer].add_scalar(tag, val, self.iters)
        self.works = []
コード例 #21
0
def train(args):
    weight_dir = os.path.join(args.log_root, 'weights')
    log_dir = os.path.join(
        args.log_root, 'logs',
        'DS-Net-{}'.format(time.strftime("%Y-%m-%d-%H-%M-%S",
                                         time.localtime())))

    data_dir = os.path.join(args.data_root, args.dataset)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 1. Setup DataLoader
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 0. Setting up DataLoader...")
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5963, 0.4372, 0.3824),
                             (0.2064, 0.1785, 0.1646))
    ])

    train_loader = SiameseNetworkDataLoader(data_dir,
                                            split='train',
                                            num_pair=args.train_pair,
                                            img_size=(args.img_row,
                                                      args.img_col),
                                            transform=transform)

    num_classes = train_loader.num_classes

    valid_loader = SiameseNetworkDataLoader(data_dir,
                                            split="val",
                                            num_pair=args.val_pair,
                                            img_size=(args.img_row,
                                                      args.img_col),
                                            transform=transform)

    tra_loader = data.DataLoader(train_loader,
                                 batch_size=args.batch_size,
                                 num_workers=int(multiprocessing.cpu_count() /
                                                 2),
                                 shuffle=True,
                                 collate_fn=train_loader.collate_fn)
    val_loader = data.DataLoader(valid_loader,
                                 batch_size=args.batch_size,
                                 num_workers=int(multiprocessing.cpu_count() /
                                                 2),
                                 shuffle=True,
                                 collate_fn=train_loader.collate_fn)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 2. Setup Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 1. Setting up Model...")
    model = DeepID2_ResNet()
    # model = DataParallelModel(model, device_ids=[0, 1, 2]).cuda()
    model = torch.nn.DataParallel(model, device_ids=[0]).cuda()

    # 2.1 Setup Optimizer
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # Check if model has custom optimizer
    if hasattr(model.module, 'optimizer'):
        print('> Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.learning_rate,
                                    momentum=0.90,
                                    weight_decay=5e-4,
                                    nesterov=True)
        # torch.nn.utils.clip_grad_norm(model.parameters(), max_norm=1e3, norm_type=float('inf'))

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=4,
                                                gamma=0.1)
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    # 2.2 Setup Loss
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    loss_contrastive = ContrastiveLoss()
    loss_CE = torch.nn.CrossEntropyLoss()

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 3. Resume Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 2. Model state init or resume...")
    args.start_epoch = 0
    beat_map = -100
    if args.resume is not None:
        full_path = os.path.join(weight_dir, args.resume)
        if os.path.isfile(full_path):
            print("> Loading model and optimizer from checkpoint '{}'".format(
                args.resume))

            checkpoint = torch.load(full_path)

            args.start_epoch = checkpoint['epoch']
            # beat_map = checkpoint['beat_map']
            model.load_state_dict(checkpoint['model_state'])  # weights
            optimizer.load_state_dict(
                checkpoint['optimizer_state'])  # gradient state
            del checkpoint

            print("> Loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))

        else:
            print("> No checkpoint found at '{}'".format(full_path))
            raise Exception("> No checkpoint found at '{}'".format(full_path))
    else:
        # init_weights(model, pi=0.01,
        #              pre_trained="/home/pingguo/PycharmProject/dl_project/Weights/DS-Net/mobilenetv2.pth.tar")
        init_weights(model=model, pre_trained=None)

        if args.pre_trained is not None:
            print("> Loading weights from pre-trained model '{}'".format(
                args.pre_trained))
            full_path = os.path.join(weight_dir, args.pre_trained)

            pre_weight = torch.load(full_path)

            model_dict = model.state_dict()
            pretrained_dict = {
                k: v
                for k, v in pre_weight.items() if k in model_dict
            }

            # Merge the filtered pre-trained weights into the model's own
            # state dict, then load the combined dict.
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            # Release checkpoint tensors as early as possible to free memory.
            del pre_weight
            # del model_dict
            model_dict = None
            del pretrained_dict

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4. Train Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4.0. Setup tensor-board for visualization
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    writer = None
    if args.tensor_board:
        writer = SummaryWriter(log_dir=log_dir, comment="Face_Verification")
        # Dummy pair of input images matching the network's expected size,
        # kept around for the (currently disabled) add_graph call below.
        dummy_input_1 = Variable(torch.rand(1, 3, args.img_row,
                                            args.img_col).cuda(),
                                 requires_grad=True)
        dummy_input_2 = Variable(torch.rand(1, 3, args.img_row,
                                            args.img_col).cuda(),
                                 requires_grad=True)
        # writer.add_graph(model, (dummy_input_1, dummy_input_2))

    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 3. Model Training start...")
    # Number of mini-batches per epoch; num_pair is the dataset's pair count.
    num_batches = int(
        math.ceil(tra_loader.dataset.num_pair / float(tra_loader.batch_size)))
    num_val = int(math.ceil(val_loader.dataset.num_pair))

    for epoch in np.arange(args.start_epoch, args.num_epochs):
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 4.1 Mini-Batch Training
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        model.train()
        pbar = tqdm(np.arange(num_batches))
        for train_i, (labels, images_1, images_2) in enumerate(
                tra_loader):  # One mini-Batch data, One iteration
            # Global iteration counter across epochs (1-based).
            full_iter = (epoch * num_batches) + train_i + 1

            pbar.update(1)
            pbar.set_description("> Epoch [%d/%d]" %
                                 (epoch + 1, args.num_epochs))

            # NOTE(review): requires_grad=True on input images is unusual —
            # gradients w.r.t. the inputs are never used here; confirm intent.
            images_1 = Variable(
                images_1.cuda(),
                requires_grad=True)  # Image feed into the deep neural network
            images_2 = Variable(images_2.cuda(), requires_grad=True)
            labels = Variable(labels.cuda(), requires_grad=False)

            optimizer.zero_grad()
            feats_1, feats_2, diff_preds = model(
                images_1, images_2)  # Here we have 3 output

            # !!!!!! Please Loss define !!!!!!
            # Joint loss: contrastive loss on the embedding pair plus
            # cross-entropy on the same/different classification head.
            targets = labels.type(torch.cuda.FloatTensor)
            loss_1 = loss_contrastive(feats_1, feats_2, targets)

            loss_2 = loss_CE(diff_preds, labels[:, 0])
            losses = 0.6 * loss_1 + 0.4 * loss_2
            losses.backward()  # back-propagation

            # NOTE(review): clip_grad_norm was deprecated in PyTorch 0.4+ in
            # favor of clip_grad_norm_; likewise the .data[0] accesses below
            # assume a pre-0.4 PyTorch (modern code would use .item()) —
            # confirm the targeted PyTorch version.
            torch.nn.utils.clip_grad_norm(model.parameters(), 1e3)
            optimizer.step()  # parameter update based on the current gradient
            """
            if full_iter % 3000 == 0:
                state = model.state_dict()

                save_dir = os.path.join(weight_dir, "dsnet_model.pkl")
                torch.save(state, save_dir)
            """
            pbar.set_postfix(Losses=losses.data[0])

            # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
            # 4.1.1 Verbose training process
            # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
            if (train_i + 1) % args.verbose_interval == 0:
                # ---------------------------------------- #
                # 1. Training Losses
                # ---------------------------------------- #
                loss_log = "Epoch [%d/%d], Iter: %d Loss1: \t %.4f" % (
                    epoch + 1, args.num_epochs, train_i + 1, losses.data[0])

                # ---------------------------------------- #
                # 2. Training Metrics
                # ---------------------------------------- #
                cls_preds = F.softmax(diff_preds, dim=1)
                prec = accuracy(cls_preds, labels)
                pbar.set_postfix(Accuracy=prec.data[0])

                metric_log = "Epoch [%d/%d], Iter: %d, Acc: \t %.3f" % (
                    epoch + 1, args.num_epochs, train_i + 1, prec.data[0])

                logs = loss_log + metric_log
                if args.tensor_board:
                    writer.add_scalar('Training/Loss', losses.data[0],
                                      full_iter)
                    writer.add_scalar('Training/Accuracy', prec.data[0],
                                      full_iter)

                    writer.add_text('Training/Text', logs, full_iter)

                    # Histogram of every parameter tensor — this is expensive;
                    # it runs once per verbose interval.
                    for name, param in model.named_parameters():
                        param_value = param.clone().cpu().data.numpy()
                        writer.add_histogram(name, param_value, full_iter)

        # Checkpoint at the end of every epoch (latest model, not the best).
        state = {
            "epoch": epoch + 1,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict()
        }

        save_dir = os.path.join(weight_dir, "faceVerification_model.pkl")
        torch.save(state, save_dir)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 4.2 Mini-Batch Validation
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        model.eval()

        val_loss = 0.0
        acc_val = 0.0
        vali_count = 0
        start_time = time.time()
        # NOTE(review): the while wrapper replays the whole val_loader until
        # vali_count reaches num_val, so the count can overshoot num_val by up
        # to one full pass — confirm this is intended.
        while vali_count < num_val:
            for i_val, (labels_val, images_1_val,
                        images_2_val) in enumerate(val_loader):
                vali_count += 1

                # NOTE(review): volatile=True is the pre-0.4 idiom for
                # inference; modern PyTorch uses torch.no_grad() instead.
                images_1_val = Variable(images_1_val.cuda(), volatile=True)
                images_2_val = Variable(images_2_val.cuda(), volatile=True)
                labels_val = Variable(labels_val.cuda(), volatile=True)

                feats_1, feats_2, preds_val = model(
                    images_1_val,
                    images_2_val)  # Here we have 4 output for 4 loss

                # !!!!!! Please Loss define !!!!!!
                # Same 0.6/0.4 weighted joint loss as in training.
                targets_val = labels_val.type(torch.cuda.FloatTensor)
                val_loss_1 = loss_contrastive(feats_1, feats_2, targets_val)
                val_loss_2 = loss_CE(preds_val, labels_val[:, 0])
                val_losses = 0.6 * val_loss_1 + 0.4 * val_loss_2
                val_loss += val_losses.data[0]

                # !!!!! Here calculate Metrics !!!!!
                # accumulating the confusion matrix and ious
                preds_val = F.softmax(preds_val, dim=1)
                prec_val = accuracy(preds_val, labels_val)
                acc_val += prec_val.data[0]

        print("Validation time: {}s".format(time.time() - start_time))
        # ---------------------------------------- #
        # 1. Validation Losses
        # ---------------------------------------- #
        # Average loss/accuracy over all validation iterations.
        val_loss /= vali_count
        acc_val /= vali_count

        loss_log = "Epoch [%d/%d], Loss: \t %.4f" % (epoch + 1,
                                                     args.num_epochs, val_loss)

        # ---------------------------------------- #
        # 2. Validation Metrics
        # ---------------------------------------- #
        metric_log = "Epoch [%d/%d], Acc: \t %.3f" % (epoch + 1,
                                                      args.num_epochs, acc_val)

        logs = loss_log + metric_log

        if args.tensor_board:
            writer.add_scalar('Validation/Loss', val_loss, epoch)
            writer.add_scalar('Validation/Accuracy', acc_val, epoch)

            writer.add_text('Validation/Text', logs, epoch)

            for name, param in model.named_parameters():
                writer.add_histogram(name,
                                     param.clone().cpu().data.numpy(), epoch)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 4.3 End of one Epoch
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # !!!!! Here choose suitable Metric for the best model selection !!!!!

        # NOTE(review): 'beat_map' (defined outside this view) is presumably a
        # typo for 'best_map' (best mean accuracy so far) — confirm upstream.
        if acc_val >= beat_map:
            beat_map = acc_val
            state = {
                "epoch": epoch + 1,
                "beat_map": beat_map,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict()
            }

            save_dir = os.path.join(weight_dir,
                                    "faceVerification_best_model.pkl")
            torch.save(state, save_dir)

        # Note that step should be called after validate()
        scheduler.step()

        pbar.close()

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4.4 End of Training process
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    if args.tensor_board:
        # export scalar data to JSON for external processing
        # writer.export_scalars_to_json("{}/all_scalars.json".format(log_dir))
        writer.close()
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> Training Done!!!")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
コード例 #22
0
class Engine(object):
    """
    Meta engine holding the shared training / evaluation loop for an NCF
    model.

    Subclasses are expected to attach ``self.model`` before invoking this
    constructor, because the optimizer is built from it here.
    """
    def __init__(self, config):
        """
        Initialize metrics, tensorboard logging, optimizer and loss.

        :param config: configuration dictionary
        """
        self.config = config  # keep the raw configuration around
        # Hit-Ratio / NDCG evaluator restricted to the top-10 ranked items.
        self._metron = MetronAtK(top_k=10)
        # One tensorboard run directory per experiment alias.
        self._writer = SummaryWriter(log_dir='runs/{}'.format(
            config['alias']))
        # Dump the full configuration as text for later inspection.
        self._writer.add_text('config', str(config), 0)
        # NOTE: relies on the subclass having already defined self.model.
        self.opt = use_optimizer(self.model, config)

        # Implicit feedback -> binary cross entropy.
        # (torch.nn.MSELoss() would be the explicit-feedback alternative.)
        self.crit = torch.nn.BCELoss()

    def train_single_batch(self, users, items, ratings):
        """
        Run one optimization step on a single mini-batch.

        :param users: user index tensor
        :param items: item index tensor
        :param ratings: float target tensor
        :return: scalar loss value for this batch
        """
        assert hasattr(self, 'model'), 'Please specify the exact model !'

        # (cuda transfer of the batch intentionally left to the caller)

        self.opt.zero_grad()
        # Forward pass: predicted score for each (user, item) pair.
        scores = self.model(users, items)
        batch_loss = self.crit(scores.view(-1), ratings)
        # Backward pass followed by a parameter update.
        batch_loss.backward()
        self.opt.step()
        # Hand back a plain python float.
        return batch_loss.item()

    def train_an_epoch(self, train_loader, epoch_id):
        """
        Iterate once over the training data, batch by batch.

        :param train_loader: loader yielding (user, item, rating) batches
        :param epoch_id: current epoch index (used as the tensorboard step)
        :return:
        """
        assert hasattr(self, 'model'), 'Please specify the exact model !'

        # Put the model into training mode and reset the loss accumulator.
        self.model.train()
        total_loss = 0

        for batch_id, batch in enumerate(train_loader):
            assert isinstance(batch[0], torch.LongTensor)

            # Unpack the batch; targets must be float for BCE.
            user, item = batch[0], batch[1]
            rating = batch[2].float()

            loss = self.train_single_batch(user, item, rating)

            print('[Training Epoch {}] Batch {}, Loss {}'.format(
                epoch_id, batch_id, loss))
            total_loss += loss

        # One aggregated loss point per epoch on tensorboard.
        self._writer.add_scalar('model/loss', total_loss, epoch_id)

    def evaluate(self, evaluate_data, epoch_id):
        """
        Evaluate the model on held-out data and report HR@10 / NDCG@10.

        :param evaluate_data: sequence of [test_users, test_items,
                              negative_users, negative_items] tensors
        :param epoch_id: current epoch index (used as the tensorboard step)
        :return: (hit_ratio, ndcg) metric values
        """
        assert hasattr(self, 'model'), 'Please specify the exact model !'

        # Evaluation mode: disables dropout / batch-norm updates.
        self.model.eval()

        # No gradients are needed while scoring.
        with torch.no_grad():
            # Positive (held-out) interactions.
            test_users, test_items = evaluate_data[0], evaluate_data[1]
            # Sampled negative interactions.
            negative_users, negative_items = evaluate_data[2], evaluate_data[3]

            # (cuda transfers intentionally left to the caller)

            # Score both the positive and the negative pairs.
            positive_scores = self.model(test_users, test_items)
            negative_scores = self.model(negative_users, negative_items)

            # Hand everything to the metric helper as flat python lists.
            self._metron.subjects = [
                test_users.data.view(-1).tolist(),
                test_items.data.view(-1).tolist(),
                positive_scores.data.view(-1).tolist(),
                negative_users.data.view(-1).tolist(),
                negative_items.data.view(-1).tolist(),
                negative_scores.data.view(-1).tolist()
            ]

        # Compute ranking metrics from the collected subjects.
        hit_ratio = self._metron.cal_hit_ratio()
        ndcg = self._metron.cal_ndcg()

        # Publish both metrics to tensorboard.
        self._writer.add_scalar('performance/HR', hit_ratio, epoch_id)
        self._writer.add_scalar('performance/NDCG', ndcg, epoch_id)

        print('[Evaluating Epoch {}] HR = {:.4f}, NDCG = {:.4f}'.format(
            epoch_id, hit_ratio, ndcg))

        return hit_ratio, ndcg

    def save(self, alias, epoch_id, hit_ratio, ndcg):
        """
        Persist a model checkpoint for the current epoch.

        :param alias: experiment alias encoded into the checkpoint path
        :param epoch_id: current epoch index
        :param hit_ratio: value of the Hit-Ratio metric
        :param ndcg: value of the NDCG metric
        """
        assert hasattr(self, 'model'), 'Please specify the exact model !'

        # The configured path template encodes alias / epoch / metric values.
        target_path = self.config['model_dir'].format(alias, epoch_id,
                                                      hit_ratio, ndcg)
        save_checkpoint(self.model, target_path)
コード例 #23
0
ファイル: trainer.py プロジェクト: bmeatayi/neurogan
class TrainerCGAN(object):
    def __init__(self,
                 optimizer_g=torch.optim.Adam,
                 optimizer_d=torch.optim.Adam,
                 log_folder='results',
                 gan_mode='js',
                 lambda_gp=None,
                 grad_mode='gs',
                 gs_temp=None,
                 n_neuron=None,
                 gen_loss_mode='bce'):
        r"""
        Trainer class for conditional GAN
        Args:
            optimizer_g (torch.optim.Optimizer): optimizer of generator
            optimizer_d (torch.optim.Optimizer): optimizer of discriminator
            log_folder (str): path to the log folder
            gan_mode (str): mode of training
                            'js': Jensen-Shannon divergence
                            'wgan-gp': Wasserstein GAN with Gradient Penalty
                            'sn': Spectral Normalization (NOT IMPLEMENTED YET)
            lambda_gp (float): Gradient penalty scale factor (only for 'wgan-gp' gan_mode)
            grad_mode (str): Gradient estimator method
                                'gs': binary Gumbel-Softmax relaxation
                                'rebar': REBAR method
                                'reinforce': REINFORCE method
            gs_temp (float): Gumbel-Softmax temperature
            n_neuron (int): number of neurons
            gen_loss_mode (str): Generator loss function (options: 'bce' for binary cross entropy, 'hinge')
        """
        self.log_folder = log_folder
        self.optimizer_G = optimizer_g
        self.optimizer_D = optimizer_d

        # Validate the configuration early with explicit messages.
        # NOTE(review): the gen_loss_mode message below is missing a space
        # before "is not supported!" — cosmetic, but worth fixing upstream.
        assert gan_mode in ['js', 'wgan-gp',
                            'sn'], gan_mode + ' is not supported!'
        assert grad_mode in ['gs', 'rebar',
                             'reinforce'], grad_mode + ' is not supported!'
        assert gen_loss_mode in ['bce',
                                 'hinge'], gen_loss_mode + 'is not supported!'

        if gan_mode == 'wgan-gp':
            assert lambda_gp is not None, "lambda_gp is not given!"
            self.lambda_gp = lambda_gp

        # Select the gradient estimator used to back-propagate through the
        # discrete (binary spike) samples produced by the generator.
        if grad_mode == 'gs':
            assert gs_temp is not None, 'gs_temp is not given!'
            assert n_neuron is not None, 'n_unit is not given!'
            self.gumbel_softmax = GumbelSoftmaxBinary(n_unit=n_neuron,
                                                      gs_temp=gs_temp)
        elif grad_mode == 'reinforce':
            self.bernoulli_func = torch.distributions.bernoulli.Bernoulli
        elif grad_mode == 'rebar':
            self.rebar_estimator = Rebar()
            self.bernoulli_func = torch.distributions.bernoulli.Bernoulli

        self.gan_mode = gan_mode
        self.grad_mode = grad_mode
        self.gen_loss_mode = gen_loss_mode

        # Per-iteration loss curves, filled in during train().
        self.d_loss_history = []
        self.g_loss_history = []

        os.makedirs(log_folder, exist_ok=True)
        # NOTE(review): file paths below are built as log_folder + 'name',
        # which assumes log_folder ends with a path separator — confirm.
        self.logger = SummaryWriter(log_folder)

    def _reset_loss_history(self):
        # Clear both loss curves (called at the start of each train() run).
        self.d_loss_history = []
        self.g_loss_history = []

    def train(self,
              generator,
              discriminator,
              train_loader,
              val_loader,
              lr=0.0002,
              b1=0.5,
              b2=0.999,
              log_interval=400,
              n_epochs=200,
              n_disc_train=5,
              temp_anneal=1):
        r"""
        train conditional GAN

        Args:
            generator (nn.module): Generator
            discriminator (nn.module): Discriminator
            train_loader (dataloader): train dataloader
            val_loader (dataloader): validation dataloader
            lr (float): Adam optimizer learning rate
            b1 (float): Adam optimizer beta1 parameter
            b2 (float): Adam optimizer beta2 parameter
            log_interval (int): iteration intervals for logging results
            n_epochs  (int): number of total epochs
            n_disc_train (int): train discriminator n_disc_train times vs. 1 train step of generator
            temp_anneal (float): annealing factor of Gumbel-Softmax temperature

        Returns:
            void
        """

        # Record the experiment setup as text in tensorboard.
        self.logger.add_text('G-Architecture', repr(generator))
        self.logger.add_text('D-Architecture', repr(discriminator))
        self.logger.add_text('GAN-mode', self.gan_mode)
        self.logger.add_text('Grad-mode', self.grad_mode)
        self._reset_loss_history()

        if torch.cuda.is_available():
            generator.cuda()
            discriminator.cuda()

        optim_g = self.optimizer_G(generator.parameters(),
                                   lr=lr,
                                   betas=(b1, b2))
        optim_d = self.optimizer_D(discriminator.parameters(),
                                   lr=lr,
                                   betas=(b1, b2))

        self.logger.add_text('G-optim', repr(optim_g))
        self.logger.add_text('D-optim', repr(optim_d))

        for epoch in range(n_epochs):
            for i, inputs in enumerate(train_loader):
                # Each batch is a (spike trains, stimulus) pair.
                # NOTE(review): FloatTensor is a module-level alias defined
                # outside this class (presumably torch.cuda.FloatTensor or
                # torch.FloatTensor depending on device) — confirm.
                spike, stim = inputs
                batch_size = spike.shape[0]

                real_sample = spike.type(FloatTensor)
                stim = stim.type(FloatTensor)

                real_label = FloatTensor(batch_size, 1).fill_(1.0)
                fake_label = FloatTensor(batch_size, 1).fill_(0.0)

                # Generator is updated once per n_disc_train discriminator
                # updates (i == 0 included, so g_loss is bound on iter 0).
                if i % n_disc_train == 0:
                    # Train Generator
                    optim_g.zero_grad()
                    # discriminator.eval()
                    z = FloatTensor(
                        np.random.normal(0, 1,
                                         (batch_size, generator.latent_dim)))
                    fake_logits = generator(z, stim)
                    if self.grad_mode == 'rebar':
                        # REBAR handles its own backward pass internally.
                        g_loss = self.rebar_estimator.step(
                            logits=fake_logits,
                            discriminator=discriminator,
                            stim=stim)
                    else:
                        fake_samples = self._logit2sample(fake_logits)
                        pred_fake = discriminator(fake_samples, stim)

                        g_loss = self._compute_g_loss(
                            fake_logits=fake_logits,
                            pred_fake=pred_fake,
                            fake_samples=fake_samples)
                        if self.grad_mode == 'reinforce':
                            # REINFORCE: inject the score-function gradient
                            # directly into the logits' backward pass.
                            fake_logits.backward(g_loss)
                            g_loss = g_loss.mean()
                        else:
                            g_loss.backward()
                        g_loss = g_loss.data.cpu().numpy()

                    optim_g.step()

                    self.g_loss_history.append(g_loss)

                # Train discriminator
                discriminator.train()
                generator.eval()
                optim_d.zero_grad()

                z = FloatTensor(
                    np.random.normal(0, 1, (batch_size, generator.latent_dim)))
                fake_logits = generator(z, stim)
                pred_real = discriminator(real_sample, stim)

                if self.gan_mode == 'wgan-gp':
                    pred_fake = discriminator(self._logit2sample(fake_logits),
                                              stim)
                    grad_penalty = self.compute_gp(discriminator, real_sample,
                                                   fake_logits, stim)
                    # Wasserstein critic loss plus gradient penalty.
                    d_loss = torch.mean(pred_fake) - torch.mean(
                        pred_real) + self.lambda_gp * grad_penalty
                elif self.gan_mode == 'js' or self.gan_mode == 'sn':
                    pred_fake = discriminator(self._logit2sample(fake_logits),
                                              stim)
                    # Standard GAN: BCE against real/fake labels.
                    d_real_loss = F.binary_cross_entropy_with_logits(
                        pred_real, real_label)
                    d_fake_loss = F.binary_cross_entropy_with_logits(
                        pred_fake, fake_label)
                    d_loss = (d_real_loss + d_fake_loss) / 2

                d_loss.backward()
                optim_d.step()
                d_loss = d_loss.data.cpu().numpy()
                self.d_loss_history.append(d_loss)
                generator.train()

                print(
                    f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(train_loader)}] [D loss: {d_loss}] [G loss: {g_loss}]"
                )  # [Temp: {self.gumbel_softmax.temperature}]")
                # NOTE(review): no global_step is passed here, so tensorboard
                # assigns its own step to each point — confirm intended.
                self.logger.add_scalar('d_loss', d_loss)
                self.logger.add_scalar('g_loss', g_loss)
                batches_done = epoch * len(train_loader) + i
                if batches_done % log_interval == 0:
                    self.log_result(generator,
                                    discriminator,
                                    batches_done,
                                    val_loader=val_loader,
                                    is_save=True)

            # Temperature annealing
            if self.grad_mode == 'gs':
                self.gumbel_softmax.temperature *= temp_anneal
                # if self.gumbel_softmax.temperature < .005:
                #     self.gumbel_softmax.temperature == .005

            # Per-epoch checkpoint of both networks.
            # NOTE(review): these np.save calls store the LAST batch's scalar
            # g_loss/d_loss and overwrite the same files each epoch; the full
            # histories are only written after training finishes — confirm.
            torch.save(generator, self.log_folder + 'generator.pt')
            torch.save(discriminator, self.log_folder + 'discriminator.pt')
            np.save(self.log_folder + 'g_loss.npy', g_loss)
            np.save(self.log_folder + 'd_loss.npy', d_loss)
            # NOTE(review): d_real_loss/d_fake_loss are only bound in the
            # 'js'/'sn' branch, so this del raises NameError under
            # gan_mode == 'wgan-gp' — looks like a bug, confirm.
            del spike, stim, inputs, real_sample, fake_logits, pred_fake, pred_real
            del d_real_loss, d_fake_loss

        # Final evaluation snapshot after the last epoch.
        self.log_result(generator,
                        discriminator,
                        batches_done,
                        val_loader=val_loader,
                        is_save=True)

        self.plot_loss_history()
        self.logger.export_scalars_to_json(self.log_folder +
                                           "./all_scalars.json")
        self.logger.close()
        torch.save(generator, self.log_folder + 'generator.pt')
        torch.save(discriminator, self.log_folder + 'discriminator.pt')
        np.save(self.log_folder + 'g_loss.npy', self.g_loss_history)
        np.save(self.log_folder + 'd_loss.npy', self.d_loss_history)

    def _logit2sample(self, fake_logits):
        r"""
        Converts logits to samples based on the gradient estimator method
        Args:
            fake_logits (torch.tensor): logits generated by generator

        Returns:
            sample (torch.tensor): binary or relaxed samples
        """
        if self.grad_mode == 'gs':
            # Relaxed (differentiable) binary samples.
            return self.gumbel_softmax(fake_logits)
        elif self.grad_mode == 'reinforce' or self.grad_mode == 'rebar':
            # Hard Bernoulli samples; the sampler is kept so that
            # _compute_g_loss can later query log-probabilities.
            self.sampler = self.bernoulli_func(logits=fake_logits)
            return self.sampler.sample()

    def compute_gp(self, discriminator, real_sample, fake_sample, stim):
        r"""
        Computes gradient penalty in WGAN-GP method
        Reference: Gulrajani et. al. (2017). Improved training of Wasserstein GANs.

        Args:
            discriminator (torch.nn.Module): Discriminator
            real_sample (torch.tensor): Real samples
            fake_sample (torch.tensor): Fake samples
            stim (torch.tensor): Stimulation

        Returns:
            gradient penalty (torch.tensor)
        """
        alpha = FloatTensor(np.random.rand(real_sample.size(0), 1, 1))
        # NOTE(review): standard WGAN-GP interpolates
        # alpha * real + (1 - alpha) * fake; the minus sign here deviates
        # from Gulrajani et al. — looks like a bug, confirm.
        ip = autograd.Variable(alpha * real_sample - (1 - alpha) * fake_sample,
                               requires_grad=True)
        disc_ip = discriminator(ip, stim)
        pre_grads = FloatTensor(real_sample.size(0), 1).fill_(1.0)

        # Gradient of the critic output w.r.t. the interpolated points.
        grads = torch.autograd.grad(outputs=disc_ip,
                                    inputs=ip,
                                    grad_outputs=pre_grads,
                                    retain_graph=True,
                                    create_graph=True,
                                    only_inputs=True)[0]
        # Penalize deviation of the gradient norm from 1.
        return ((grads.norm(2, dim=1) - 1)**2).mean()

    def _compute_g_loss(self, fake_logits, pred_fake, fake_samples):
        r"""
        Computes loss for the generator
        Args:
            fake_logits (torch.tensor): logits produced by the generator
            pred_fake (torch.tensor): output of the discriminator (logit)
            fake_samples (torch.tensor): generated samples by the generator (discretized or relaxed version)

        Returns:
            g_loss (torch.tensor): loss value
        """
        if self.gen_loss_mode == 'bce':
            # NOTE(review): despite the 'bce' name this is the non-saturating
            # -mean(D(fake)) generator objective, not a literal BCE — confirm.
            g_loss = -pred_fake.mean()
        elif self.gen_loss_mode == 'hinge':
            # NOTE(review): 'hinge' leaves g_loss unbound, which would raise
            # UnboundLocalError below — unimplemented path.
            pass
            # TODO: Implement hinge loss

        if self.grad_mode == 'reinforce':
            # Score-function (REINFORCE) estimator: weight the per-logit
            # grad of the log-probability by the detached reward.
            log_probability = self.sampler.log_prob(fake_samples)
            d_log_probability = autograd.grad(
                log_probability,
                fake_logits,
                grad_outputs=torch.ones_like(log_probability),
                retain_graph=False)[0]
            g_loss = g_loss.detach() * d_log_probability.detach()

        return g_loss

    def generate_data(self,
                      generator,
                      discriminator,
                      val_loader,
                      n_sample=200,
                      is_save=False):
        r"""
        Draw n_sample full passes over val_loader from the generator and
        collect the matching real data, optionally saving both as .npy.
        Returns (fake_data, real_data) as numpy arrays.
        """
        generator.eval()
        discriminator.eval()

        fake_data = None  # torch.zeros([0, 995, generator.n_t, generator.n_cell])
        real_data = None  # torch.zeros([0, 995, generator.n_t, generator.n_cell])

        for j in range(n_sample):
            temp_gen = torch.zeros([0, generator.n_t, generator.n_cell])
            temp_real = torch.zeros([0, generator.n_t, generator.n_cell])
            for i, inputs in enumerate(val_loader):
                cnt, stim = inputs
                batch_size = cnt.shape[0]
                stim = stim.type(FloatTensor)
                z = FloatTensor(
                    np.random.normal(0, 1, (batch_size, generator.latent_dim)))
                fake_sample = self._logit2sample(generator(z, stim))
                if self.grad_mode == 'gs':
                    # Binarize the relaxed Gumbel-Softmax samples at 0.5.
                    fake_sample[fake_sample >= .5] = 1
                    fake_sample[fake_sample < .5] = 0
                temp_gen = torch.cat((temp_gen, fake_sample.detach()))
                temp_real = torch.cat((temp_real, cnt.type(FloatTensor)))
            if fake_data is None:
                # Allocate the accumulators lazily once the per-pass size
                # (temp_gen.size(0)) is known.
                fake_data = torch.zeros(
                    [0, temp_gen.size(0), generator.n_t, generator.n_cell])
                real_data = torch.zeros(
                    [0, temp_gen.size(0), generator.n_t, generator.n_cell])

            fake_data = torch.cat((fake_data, temp_gen.unsqueeze(0)))
            real_data = torch.cat((real_data, temp_real.unsqueeze(0)))
            del temp_gen, temp_real

        fake_data = np.squeeze(fake_data.cpu().numpy())
        real_data = np.squeeze(real_data.cpu().numpy())
        if is_save:
            np.save(self.log_folder + 'fake_data.npy', fake_data)
            np.save(self.log_folder + 'real_data.npy', real_data)

        return fake_data, real_data

    def log_result(self,
                   generator,
                   discriminator,
                   batches_done,
                   val_loader,
                   n_sample=200,
                   is_save=False):
        r"""
        Generate evaluation data and write diagnostic plots (mean, std,
        correlations, per-bin rates) to a per-iteration PDF in log_folder.
        Restores both networks to train mode before returning.
        """

        fake_data, real_data = self.generate_data(generator,
                                                  discriminator,
                                                  val_loader,
                                                  n_sample,
                                                  is_save=is_save)

        pdf = PdfPages(self.log_folder + 'iter_' + str(batches_done) + '.pdf')
        if fake_data.ndim == 2:
            # Single-neuron case: re-add the trailing neuron axis so the
            # plotting helpers always see 3-D data.
            fake_data = fake_data[:, :, np.newaxis]
            real_data = real_data[:, :, np.newaxis]

        # Evaluation metrics
        viz = Visualize(real_data,
                        time=1)  # specify time for correct firing rates
        fig, ax = plt.subplots(figsize=(5, 5))
        viz.mean(fake_data, '')
        pdf.savefig(bbox_inches='tight')

        viz.std(fake_data, '')
        pdf.savefig(bbox_inches='tight')

        viz.corr(fake_data, model='')
        pdf.savefig(bbox_inches='tight')

        viz.noise_corr(fake_data, model='')
        pdf.savefig(bbox_inches='tight')

        # real_glm_filters = np.load('..//dataset//GLM_2D_10n_shared_noise//W.npy')
        # real_glm_biases = np.load('..//dataset//GLM_2D_10n_shared_noise//bias.npy')
        # real_w_shared_noise = .7
        # #
        # gen_glm_filters = generator.GLM.weight.detach().cpu().numpy().reshape(real_glm_filters.shape)
        # gen_glm_biases = generator.GLM.bias.detach().cpu().numpy()
        # gen_w_shared_noise = generator.shn_layer.weight.detach().cpu().numpy()
        #
        # fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        # ax[0].set_title('GLM filter parameters')
        # ax[0].plot([-1, 2], [-1, 2], 'black')
        # ax[0].plot(real_glm_filters.flatten(), np.flip(gen_glm_filters, axis=(1, 2)).flatten(), '.')
        # ax[0].plot([-real_w_shared_noise, real_w_shared_noise], [gen_w_shared_noise[0], gen_w_shared_noise[0]]
        #            , '*', markersize=5, label='Shared noise scale')
        # ax[1].set_title('GLM biases')
        # ax[1].plot([-4, -1], [-4, -1], 'black')
        # ax[1].plot(real_glm_biases, gen_glm_biases, '.')
        # pdf.savefig(bbox_inches='tight')
        #
        viz.mean_per_bin(fake_data,
                         'GAN 1',
                         neurons=[],
                         label='Neuron ',
                         figsize=[15, 10])
        pdf.savefig(bbox_inches='tight')

        # for i, spikes in enumerate(fake_data.transpose(2, 0, 1)):
        #     fig, ax = plt.subplots(2, 2, figsize=(20, 5))
        #     fig.suptitle('Neuron %i, iteration %i' % (i, batches_done))
        #     ax[0, 0].imshow(spikes)
        #     ax[0, 0].set_xlabel('time')
        #     ax[0, 0].set_ylabel('repetitions')
        #     ax[0, 0].set_xticks([])
        #     ax[0, 0].set_yticks([])
        #     ax[0, 0].set_title('GAN data')
        #
        #     ax[1, 0].imshow(real_data[:, :, i])
        #     ax[1, 0].set_xlabel('time')
        #     ax[1, 0].set_ylabel('repetitions')
        #     ax[1, 0].set_xticks([])
        #     ax[1, 0].set_yticks([])
        #     ax[1, 0].set_title('Real data')
        #
        #     ax[0, 1].plot(real_data[:, :, i].mean(axis=0), label='Real data')
        #     ax[0, 1].plot(spikes.mean(axis=0), label='GAN data')
        #     ax[0, 1].set_ylim(0, 1)
        #     ax[0, 1].set_xlabel('time')
        #     ax[0, 1].set_title('Mean firing rate')
        #     ax[0, 1].legend(loc=1)
        #
        #     ax[1, 1].plot(real_data[:, :, i].std(axis=0), label='Real data')
        #     ax[1, 1].plot(spikes.std(axis=0), label='GAN data')
        #     ax[1, 1].set_xlabel('time')
        #     ax[1, 1].set_title('Std of spike data')
        #     ax[1, 1].legend(loc=1)
        #     ax[1, 1].set_ylim(0, 1)
        #     plt.subplots_adjust(top=0.9, bottom=0.1, hspace=.8, wspace=0.2)
        #     pdf.savefig(fig)

        pdf.close()
        plt.close()

        # GLM_filters = generator.GLM.weight.detach().cpu().numpy()
        # N = GLM_filters.shape[0]
        # fig, ax = plt.subplots(1, N, figsize=(40, 5))
        # for i, f in enumerate(GLM_filters):
        #     ax[i].imshow(np.flip(f.reshape((30, 40)), axis=(0, 1)))
        #     ax[i].set_xlabel('x')
        #     ax[i].set_ylabel('t')
        #     ax[i].set_xticks([])
        #     ax[i].set_yticks([])
        #     ax[i].set_title('Neuron' + str(i))
        #
        # plt.savefig(self.log_folder + 'filt %i.jpg' % batches_done, dpi=120)
        # plt.close()

        # PLOT FILTERS
        # pdf = PdfPages(self.log_folder + 'filt_iter_' + str(batches_done) + '.pdf')
        # conv1filt = generator.conv1.weight.data.detach().cpu().numpy()
        # conv2filt = generator.conv2.weight.data.detach().cpu().numpy()
        # fcFilt = generator.fc.weight.detach().cpu().view(-1, 2, 24, 24).numpy()
        #
        # fig, ax = plt.subplots(*conv1filt.shape[0:2], figsize=(10, 5))
        # vmin, vmax = conv1filt.min(), conv1filt.max()
        # for (filtRow, axRow) in zip(conv1filt, ax):
        #     for (filt, axis) in zip(filtRow, axRow):
        #         axis.imshow(filt, vmin=vmin, vmax=vmax)
        # pdf.savefig(bbox_inches='tight')
        #
        # fig, ax = plt.subplots(*conv2filt.shape[0:2], figsize=(10, 5))
        # vmin, vmax = conv2filt.min(), conv2filt.max()
        # for (filtRow, axRow) in zip(conv2filt, ax):
        #     for (filt, axis) in zip(filtRow, axRow):
        #         axis.imshow(filt, vmin=vmin, vmax=vmax)
        # pdf.savefig(bbox_inches='tight')
        #
        # fig, ax = plt.subplots(*fcFilt.shape[0:2], figsize=(10, 20))
        # vmin, vmax = fcFilt.min(), fcFilt.max()
        # for (filtRow, axRow) in zip(fcFilt, ax):
        #     for (filt, axis) in zip(filtRow, axRow):
        #         axis.imshow(filt, vmin=vmin, vmax=vmax)
        # pdf.savefig(bbox_inches='tight')
        # pdf.close()
        plt.close()

        generator.train()
        discriminator.train()
        del real_data, fake_data

    def plot_loss_history(self):
        r"""
        Plot the discriminator and generator loss curves collected during
        train() and save the figure as loss_history.jpg in log_folder.
        """
        plotprop = PlotProps()
        fig = plotprop.init_figure(figsize=(14, 7))
        ax = plotprop.init_subplot(title='Loss history',
                                   tot_tup=(1, 1),
                                   sp_tup=(0, 0),
                                   xlabel='Iteration',
                                   ylabel='Value')

        ax.plot(np.arange(0, len(self.d_loss_history)),
                self.d_loss_history,
                linewidth=2.5,
                label='Discriminator loss')
        # The generator is updated less often, so stretch its curve over the
        # same x-range as the discriminator's.
        ax.plot(np.linspace(0, len(self.d_loss_history),
                            len(self.g_loss_history)),
                self.g_loss_history,
                linewidth=2.5,
                label='Generator loss')
        plotprop.legend()
        plt.savefig(self.log_folder + 'loss_history.jpg', dpi=200)
        plt.close()
コード例 #24
0
# Project-level imports for the CDARTS architecture-search entry script.
from lib.models.cdarts_controller import CDARTSController
from lib.utils.visualize import plot
from lib.utils import utils
from lib.core.search_function import search, retrain_warmup

from lib.config import SearchConfig
config = SearchConfig()

# Pick the dataset-specific search-set loader from the configured name.
if 'cifar' in config.dataset:
    from lib.datasets.cifar import get_search_datasets
elif 'imagenet' in config.dataset:
    from lib.datasets.imagenet import get_search_datasets

# tensorboard
writer = SummaryWriter(log_dir=os.path.join(config.path, "tb"))
writer.add_text('config', config.as_markdown(), 0)

# File logger named after the experiment; only rank 0 prints the full
# config to avoid duplicated output in distributed runs.
logger = utils.get_logger(
    os.path.join(config.path, "{}.log".format(config.name)))
if config.local_rank == 0:
    config.print_params(logger.info)

# Create the retrain output directory if missing.
# NOTE(review): the bare except also swallows errors other than
# "already exists" (e.g. permissions); os.makedirs(..., exist_ok=True)
# would be the targeted fix.
try:
    os.makedirs(config.retrain_path)
except:
    pass
if config.use_apex:
    import apex
    from apex.parallel import DistributedDataParallel as DDP
else:
コード例 #25
0
def main(epoch = 200, save_path = './checkpoint/', load_path = './checkpoint/KoGPT2_checkpoint_long.tar',
         data_file_path = 'dataset/lyrics_dataset.txt', batch_size = 8, summary_url = 'runs/', new = 0, text_size = 100):
    """Fine-tune the pretrained KoGPT-2 language model on a lyrics corpus.

    Downloads the pretrained weights and sentencepiece vocabulary, optionally
    resumes from a checkpoint at ``load_path``, then trains with Adam while
    logging losses to tensorboard, periodically generating sample text and
    saving checkpoints.

    Args:
        epoch: number of training epochs.
        save_path: directory prefix for checkpoint files.
        load_path: checkpoint to resume from (ignored if loading fails).
        data_file_path: path to the plain-text training corpus.
        batch_size: sequences per training batch.
        summary_url: tensorboard log directory.
        new: if truthy, restart the global step counter at 0.
        text_size: length of the periodically generated sample text.
    """
    ctx = 'cuda'
    cachedir = '~/kogpt2/'
    summary = SummaryWriter(summary_url)

    pytorch_kogpt2 = {
        'url': 'https://kobert.blob.core.windows.net/models/kogpt2/pytorch/pytorch_kogpt2_676e9bcfa7.params',
        'fname': 'pytorch_kogpt2_676e9bcfa7.params',
        'chksum': '676e9bcfa7'
    }
    kogpt2_config = {
        "initializer_range": 0.02,
        "layer_norm_epsilon": 1e-05,
        "n_ctx": 1024,
        "n_embd": 768,
        "n_head": 12,
        "n_layer": 12,
        "n_positions": 1024,
        "vocab_size": 50000
    }

    # Download the pretrained model weights (cached under ~/kogpt2/).
    model_info = pytorch_kogpt2
    model_path = download(model_info['url'],
                          model_info['fname'],
                          model_info['chksum'],
                          cachedir=cachedir)
    # Download the sentencepiece vocabulary the same way.
    vocab_info = tokenizer
    vocab_path = download(vocab_info['url'],
                          vocab_info['fname'],
                          vocab_info['chksum'],
                          cachedir=cachedir)

    # Instantiate GPT2LMHeadModel for KoGPT-2 language-model training.
    kogpt2model = GPT2LMHeadModel(config=GPT2Config.from_dict(kogpt2_config))

    # Load the downloaded pretrained weights from model_path.
    kogpt2model.load_state_dict(torch.load(model_path))

    device = torch.device(ctx)
    kogpt2model.to(device)
    count = 0
    # Try to resume from an existing fine-tuning checkpoint.
    try:
        checkpoint = torch.load(load_path, map_location=device)

        # Re-create the model and load the fine-tuned weights on top.
        kogpt2model = GPT2LMHeadModel(config=GPT2Config.from_dict(kogpt2_config))
        kogpt2model.load_state_dict(checkpoint['model_state_dict'])

        kogpt2model.eval()
    except Exception:
        print("count 0 : ", load_path)
    else:
        # Recover the global step from the digits in the checkpoint name.
        print("count check : ", re.findall(r"\d+", load_path))
        count = max([int(i) for i in (re.findall(r"\d+", load_path))])

    if new:
        count = 0
    # Switch back to training mode to continue fine-tuning.
    kogpt2model.train()
    vocab_b_obj = gluonnlp.vocab.BERTVocab.from_sentencepiece(vocab_path,
                                                              mask_token=None,
                                                              sep_token=None,
                                                              cls_token=None,
                                                              unknown_token='<unk>',
                                                              padding_token='<pad>',
                                                              bos_token='<s>',
                                                              eos_token='</s>')

    tok_path = get_tokenizer()
    model, vocab = kogpt2model, vocab_b_obj
    sentencepieceTokenizer = SentencepieceTokenizer(tok_path)

    dataset = Read_Dataset(data_file_path, vocab, sentencepieceTokenizer)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

    learning_rate = 3e-5
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    ## train
    # vocab.token_to_idx["\n"] = vocab.token_to_idx["<unused0>"]
    # del vocab.token_to_idx["<unused0>"]
    # vocab.token_to_idx["<|endoftext|>"] = vocab.token_to_idx["<unused1>"]
    # del vocab.token_to_idx["<unused1>"]

    model = model.to(ctx)
    tok = SentencepieceTokenizer(tok_path)

    print('KoGPT-2 Transfer Learning Start')
    # Exponentially-weighted running loss: (weighted loss sum, weight sum).
    avg_loss = (0.0, 0.0)
    for epoch in range(epoch):
        for data in data_loader:
            optimizer.zero_grad()
            data = torch.stack(data)  # batch arrives as a list of Tensors; stack into one
            data = data.transpose(1, 0)
            data = data.to(ctx)
            # Re-place the model on the device every step: the sampling
            # branch below moves it to the CPU.
            model = model.to(ctx)

            outputs = model(data, labels=data)
            loss, logits = outputs[:2]
            loss.backward()
            # Fixed: accumulate a Python float via loss.item() instead of
            # the loss tensor itself, which kept the whole autograd graph
            # alive and leaked GPU memory across iterations.
            avg_loss = (avg_loss[0] * 0.99 + loss.item(),
                        avg_loss[1] * 0.99 + 1.0)
            optimizer.step()

            if count % 10 == 0:
                print('epoch no.{0} train no.{1}  loss = {2:.5f} avg_loss = {3:.5f}'.format(epoch, count, loss, avg_loss[0] / avg_loss[1]))
                summary.add_scalar('loss/avg_loss', avg_loss[0] / avg_loss[1], count)
                summary.add_scalar('loss/loss', loss, count)
                # print("save")
                # torch.save({
                # 	'epoch': epoch,
                # 	'train_no': count,
                # 	'model_state_dict': model.state_dict(),
                # 	'optimizer_state_dict': optimizer.state_dict(),
                # 	'loss': loss
                # }, save_path + 'KoGPT2_checkpoint_' + str(count) + '.tar')

                # Periodically generate a sample to monitor text quality.
                if (count > 0 and count % 1000 == 0) or (len(data) < batch_size):
                    sent = sample_sequence(model.to("cpu"), tok, vocab, sent="사랑", text_size=text_size, temperature=0.7, top_p=0.8, top_k=40)
                    sent = sent.replace("<unused0>", "\n")  # inefficient, but restores the newlines
                    sent = auto_enter(sent)
                    print(sent)
                    summary.add_text('Text', sent, count)
                    del sent

            #########################################
            if (count > 0 and count % 18500 == 0):
                # Save a full training checkpoint.
                try:
                    torch.save({
                        'epoch': epoch,
                        'train_no': count,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss
                    }, save_path + 'KoGPT2_checkpoint_' + str(count) + '.tar')
                except Exception:
                    pass
            count += 1
コード例 #26
0
    def forward(self, x):
        """Map raw observations to hidden features via two ReLU-activated layers."""
        h = preprocess_obs_fn(x)
        for layer in (self.fc1, self.fc2):
            h = F.relu(layer(h))
        return h

# Policy and value networks share one Adam optimizer.
pg = Policy()
vf = Value()
optimizer = optim.Adam(list(pg.parameters()) + list(vf.parameters()), lr=args.learning_rate)
loss_fn = nn.MSELoss()

# TRY NOT TO MODIFY: start the game
# Unique run name: timestamp + experiment name + seed.
experiment_name = f"{time.strftime('%b%d_%H-%M-%S')}__{args.exp_name}__{args.seed}"
writer = SummaryWriter(f"runs/{experiment_name}")
# Log all hyperparameters as a markdown table.
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
        '\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))
if args.prod_mode:
    import wandb
    wandb.init(project=args.wandb_project_name, tensorboard=True, config=vars(args), name=experiment_name)
    # In production mode, write tensorboard events to /tmp so wandb can sync them.
    writer = SummaryWriter(f"/tmp/{experiment_name}")
global_step = 0
while global_step < args.total_timesteps:
    next_obs = np.array(env.reset())
    actions = np.empty((args.episode_length,), dtype=object)
    rewards, dones = np.zeros((2, args.episode_length))
    obs = np.empty((args.episode_length,) + env.observation_space.shape)
    
    # TODO: put other storage logic here
    values = torch.zeros((args.episode_length))
    neglogprobs = torch.zeros((args.episode_length,))
    entropys = torch.zeros((args.episode_length,))
コード例 #27
0
def worker(gpu, ngpus_per_node, callback, args):
    """Per-GPU training worker for V-trace (IMPALA-style) actor-critic on Atari.

    Sets up (optionally distributed) training on one GPU, builds the train
    and test environments, the actor-critic model and optimizer (with apex
    AMP), then runs the acting/learning loop: collect steps, compute
    V-trace returns on one minibatch of environments, backprop, shift the
    rolling window, and periodically evaluate and log.

    Args:
        gpu: local GPU index assigned to this worker.
        ngpus_per_node: GPUs per node, used to derive the global rank.
        callback: progress-reporting callable invoked once per update.
        args: parsed command-line namespace holding all hyperparameters.
    """
    args.gpu = gpu

    if args.distributed:
        # Give every worker a distinct seed and pin it to its GPU.
        args.seed += args.gpu
        torch.cuda.set_device(args.gpu)

        args.rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + args.gpu

        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:8632',
            world_size=args.world_size,
            rank=args.rank)
    else:
        args.rank = 0

    if (args.num_ales % args.num_minibatches) != 0:
        raise ValueError(
            'Number of ales({}) size is not even divisible by the minibatch size({})'
            .format(args.num_ales, args.num_minibatches))

    if args.num_steps_per_update == -1:
        args.num_steps_per_update = args.num_steps

    minibatch_size = int(args.num_ales / args.num_minibatches)
    # step0 is the offset of the first freshly-generated step in the
    # rolling (num_steps + 1)-long window of stored transitions.
    step0 = args.num_steps - args.num_steps_per_update
    n_minibatch = -1

    args.use_cuda_env = args.use_cuda_env and torch.cuda.is_available()
    args.no_cuda_train = (not args.no_cuda_train) and torch.cuda.is_available()
    args.verbose = args.verbose and (args.rank == 0)

    env_device = torch.device(
        'cuda', args.gpu) if args.use_cuda_env else torch.device('cpu')
    train_device = torch.device('cuda', args.gpu) if (
        args.no_cuda_train == False) else torch.device('cpu')

    np.random.seed(args.seed)
    torch.manual_seed(np.random.randint(1, 10000))
    if args.use_cuda_env or (args.no_cuda_train == False):
        torch.cuda.manual_seed(np.random.randint(1, 10000))

    if args.rank == 0:
        # Rank 0 owns all CSV / tensorboard logging.
        if args.output_filename:
            train_csv_file = open(args.output_filename, 'w', newline='')
            train_csv_file.write(json.dumps(vars(args)))
            train_csv_file.write('\n')
            train_csv_writer = csv.writer(train_csv_file, delimiter=',')
            train_csv_writer.writerow([
                'frames', 'fps', 'total_time', 'rmean', 'rmedian', 'rmin',
                'rmax', 'lmean', 'lmedian', 'lmin', 'lmax', 'entropy',
                'value_loss', 'policy_loss'
            ])

            # Evaluation CSV gets "_test" appended before the extension.
            eval_output_filename = '.'.join([
                ''.join(args.output_filename.split('.')[:-1] + ['_test']),
                'csv'
            ])
            eval_csv_file = open(eval_output_filename, 'w', newline='')
            eval_csv_file.write(json.dumps(vars(args)))
            eval_csv_file.write('\n')
            eval_csv_writer = csv.writer(eval_csv_file, delimiter=',')
            eval_csv_writer.writerow([
                'frames', 'total_time', 'rmean', 'rmedian', 'rmin', 'rmax',
                'rstd', 'lmean', 'lmedian', 'lmin', 'lmax', 'lstd'
            ])
        else:
            train_csv_file, train_csv_writer = None, None
            eval_csv_file, eval_csv_writer = None, None

        if args.plot:
            from tensorboardX import SummaryWriter
            current_time = datetime.now().strftime('%b%d_%H-%M-%S')
            log_dir = os.path.join(args.log_dir,
                                   current_time + '_' + socket.gethostname())
            writer = SummaryWriter(log_dir=log_dir)
            for k, v in vars(args).items():
                writer.add_text(k, str(v))

        print()
        print('PyTorch  : {}'.format(torch.__version__))
        # Fixed: torch.backends.cudnn.m.cuda does not exist and raised an
        # AttributeError; the CUDA runtime version is torch.version.cuda.
        print('CUDA     : {}'.format(torch.version.cuda))
        print('CUDNN    : {}'.format(torch.backends.cudnn.version()))
        print('APEX     : {}'.format('.'.join(
            [str(i) for i in apex.amp.__version__.VERSION])))
        print()

    if train_device.type == 'cuda':
        print(cuda_device_str(train_device.index), flush=True)

    # Training environments: either OpenAI-gym vectorized ALEs or the
    # GPU-resident AtariEnv batch.
    if args.use_openai:
        train_env = create_vectorize_atari_env(
            args.env_name,
            args.seed,
            args.num_ales,
            episode_life=args.episodic_life,
            clip_rewards=False,
            max_frames=args.max_episode_length)
        observation = torch.from_numpy(train_env.reset()).squeeze(1)
    else:
        train_env = AtariEnv(args.env_name,
                             args.num_ales,
                             color_mode='gray',
                             repeat_prob=0.0,
                             device=env_device,
                             rescale=True,
                             episodic_life=args.episodic_life,
                             clip_rewards=False,
                             frameskip=4)
        train_env.train()
        observation = train_env.reset(initial_steps=args.ale_start_steps,
                                      verbose=args.verbose).squeeze(-1)

    if args.use_openai_test_env:
        test_env = create_vectorize_atari_env(args.env_name,
                                              args.seed,
                                              args.evaluation_episodes,
                                              episode_life=False,
                                              clip_rewards=False)
        test_env.reset()
    else:
        test_env = AtariEnv(args.env_name,
                            args.evaluation_episodes,
                            color_mode='gray',
                            repeat_prob=0.0,
                            device='cpu',
                            rescale=True,
                            episodic_life=False,
                            clip_rewards=False,
                            frameskip=4)

    model = ActorCritic(args.num_stack,
                        train_env.action_space,
                        normalize=args.normalize,
                        name=args.env_name)
    model = model.to(train_device).train()

    if args.rank == 0:
        print(model)
        args.model_name = model.name

    if args.use_adam:
        optimizer = optim.Adam(model.parameters(), lr=args.lr, amsgrad=True)
    else:
        optimizer = optim.RMSprop(model.parameters(),
                                  lr=args.lr,
                                  eps=args.eps,
                                  alpha=args.alpha)

    # This is the number of frames GENERATED between two updates
    num_frames_per_iter = args.num_ales * args.num_steps_per_update
    total_steps = math.ceil(args.t_max /
                            (args.world_size * num_frames_per_iter))
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=args.opt_level,
                                      loss_scale=args.loss_scale)

    if args.distributed:
        model = DDP(model, delay_allreduce=True)

    # Rolling storage: num_steps + 1 frames of stacked observations.
    shape = (args.num_steps + 1, args.num_ales, args.num_stack,
             *train_env.observation_space.shape[-2:])
    states = torch.zeros(shape, device=train_device, dtype=torch.float32)
    states[step0, :, -1] = observation.to(device=train_device,
                                          dtype=torch.float32)

    shape = (args.num_steps + 1, args.num_ales)
    values = torch.zeros(shape, device=train_device, dtype=torch.float32)
    logits = torch.zeros(
        (args.num_steps + 1, args.num_ales, train_env.action_space.n),
        device=train_device,
        dtype=torch.float32)
    returns = torch.zeros(shape, device=train_device, dtype=torch.float32)

    shape = (args.num_steps, args.num_ales)
    rewards = torch.zeros(shape, device=train_device, dtype=torch.float32)
    masks = torch.zeros(shape, device=train_device, dtype=torch.float32)
    actions = torch.zeros(shape, device=train_device, dtype=torch.long)

    # mus holds the behavior-policy probabilities of the taken actions;
    # rhos holds the clipped importance weights for one minibatch.
    mus = torch.ones(shape, device=train_device, dtype=torch.float32)
    # pis = torch.zeros(shape, device=train_device, dtype=torch.float32)
    rhos = torch.zeros((args.num_steps, minibatch_size),
                       device=train_device,
                       dtype=torch.float32)

    # These variables are used to compute average rewards for all processes.
    episode_rewards = torch.zeros(args.num_ales,
                                  device=train_device,
                                  dtype=torch.float32)
    final_rewards = torch.zeros(args.num_ales,
                                device=train_device,
                                dtype=torch.float32)
    episode_lengths = torch.zeros(args.num_ales,
                                  device=train_device,
                                  dtype=torch.float32)
    final_lengths = torch.zeros(args.num_ales,
                                device=train_device,
                                dtype=torch.float32)

    if args.use_gae:
        raise ValueError('GAE is not compatible with VTRACE')

    maybe_npy = lambda a: a.numpy() if args.use_openai else a

    torch.cuda.synchronize()

    iterator = range(total_steps)
    if args.rank == 0:
        iterator = tqdm(iterator)
        total_time = 0
        evaluation_offset = 0

    for update in iterator:

        # T = total environment frames generated so far across all workers.
        T = args.world_size * update * num_frames_per_iter
        if (args.rank == 0) and (T >= evaluation_offset):
            evaluation_offset += args.evaluation_interval
            eval_lengths, eval_rewards = evaluate(args, T, total_time, model,
                                                  test_env, eval_csv_writer,
                                                  eval_csv_file)

            if args.plot:
                writer.add_scalar('eval/rewards_mean',
                                  eval_rewards.mean().item(),
                                  T,
                                  walltime=total_time)
                writer.add_scalar('eval/lengths_mean',
                                  eval_lengths.mean().item(),
                                  T,
                                  walltime=total_time)

        start_time = time.time()

        # --- acting phase: generate num_steps_per_update transitions ---
        with torch.no_grad():

            for step in range(args.num_steps_per_update):
                nvtx.range_push('train:step')
                value, logit = model(states[step0 + step])

                # store values and logits
                values[step0 + step] = value.squeeze(-1)

                # convert actions to numpy and perform next step
                probs = torch.clamp(F.softmax(logit, dim=1),
                                    min=0.00001,
                                    max=0.99999)
                probs_action = probs.multinomial(1).to(env_device)
                # Check if the multinomial threw an exception
                # https://github.com/pytorch/pytorch/issues/7014
                torch.cuda.current_stream().synchronize()
                observation, reward, done, info = train_env.step(
                    maybe_npy(probs_action))

                if args.use_openai:
                    # convert back to pytorch tensors
                    observation = torch.from_numpy(observation)
                    reward = torch.from_numpy(reward)
                    done = torch.from_numpy(done.astype(np.uint8))
                else:
                    observation = observation.squeeze(-1).unsqueeze(1)

                # move back to training memory
                observation = observation.to(device=train_device)
                reward = reward.to(device=train_device, dtype=torch.float32)
                done = done.to(device=train_device)
                probs_action = probs_action.to(device=train_device,
                                               dtype=torch.long)

                not_done = 1.0 - done.float()

                # update rewards and actions
                actions[step0 + step].copy_(probs_action.view(-1))
                masks[step0 + step].copy_(not_done)
                rewards[step0 + step].copy_(reward.sign())

                #mus[step0 + step] = F.softmax(logit, dim=1).gather(1, actions[step0 + step].view(-1).unsqueeze(-1)).view(-1)
                mus[step0 + step] = torch.clamp(F.softmax(logit, dim=1).gather(
                    1, actions[step0 + step].view(-1).unsqueeze(-1)).view(-1),
                                                min=0.00001,
                                                max=0.99999)

                # update next observations
                states[step0 + step + 1, :, :-1].copy_(states[step0 + step, :,
                                                              1:])
                states[step0 + step + 1] *= not_done.view(
                    -1, *[1] * (observation.dim() - 1))
                states[step0 + step + 1, :,
                       -1].copy_(observation.view(-1,
                                                  *states.size()[-2:]))

                # update episodic reward counters
                episode_rewards += reward
                final_rewards[done] = episode_rewards[done]
                episode_rewards *= not_done

                episode_lengths += not_done
                final_lengths[done] = episode_lengths[done]
                episode_lengths *= not_done
                nvtx.range_pop()

        # --- learning phase: one minibatch of environments per update ---
        n_minibatch = (n_minibatch + 1) % args.num_minibatches
        min_ale_index = int(n_minibatch * minibatch_size)
        max_ale_index = min_ale_index + minibatch_size

        # compute v-trace using the recursive method (remark 1 in IMPALA paper)
        # value_next_step, logit = model(states[-1:, min_ale_index:max_ale_index, :, : ,:].contiguous().view(-1, *states.size()[-3:]))
        # returns[-1, min_ale_index:max_ale_index] = value_next_step.squeeze()
        # for step in reversed(range(args.num_steps)):
        #     value, logit = model(states[step, min_ale_index:max_ale_index, :, : ,:].contiguous().view(-1, *states.size()[-3:]))
        #     pis = F.softmax(logit, dim=1).gather(1, actions[step, min_ale_index:max_ale_index].view(-1).unsqueeze(-1)).view(-1)
        #     c = torch.clamp(pis / mus[step, min_ale_index:max_ale_index], max=c_)
        #     rhos[step, :] = torch.clamp(pis / mus[step, min_ale_index:max_ale_index], max=rho_)
        #     delta_value = rhos[step, :] * (rewards[step, min_ale_index:max_ale_index] + (args.gamma * value_next_step - value).squeeze())
        #     returns[step, min_ale_index:max_ale_index] = value.squeeze() + delta_value + args.gamma * c * \
        #             (returns[step + 1, min_ale_index:max_ale_index] - value_next_step.squeeze())
        #     value_next_step = value

        nvtx.range_push('train:compute_values')
        value, logit = model(
            states[:, min_ale_index:max_ale_index, :, :, :].contiguous().view(
                -1,
                *states.size()[-3:]))
        batch_value = value.detach().view((args.num_steps + 1, minibatch_size))
        batch_probs = F.softmax(logit.detach()[:(args.num_steps *
                                                 minibatch_size), :],
                                dim=1)
        batch_pis = batch_probs.gather(
            1, actions[:, min_ale_index:max_ale_index].contiguous().view(
                -1).unsqueeze(-1)).view((args.num_steps, minibatch_size))
        returns[-1, min_ale_index:max_ale_index] = batch_value[-1]

        # Backward recursion for the V-trace targets (no gradients).
        with torch.no_grad():
            for step in reversed(range(args.num_steps)):
                c = torch.clamp(batch_pis[step, :] /
                                mus[step, min_ale_index:max_ale_index],
                                max=args.c_hat)
                rhos[step, :] = torch.clamp(
                    batch_pis[step, :] /
                    mus[step, min_ale_index:max_ale_index],
                    max=args.rho_hat)
                delta_value = rhos[step, :] * (
                    rewards[step, min_ale_index:max_ale_index] +
                    (args.gamma * batch_value[step + 1] -
                     batch_value[step]).squeeze())
                returns[step, min_ale_index:max_ale_index] = \
                        batch_value[step, :].squeeze() + delta_value + args.gamma * c * \
                        (returns[step + 1, min_ale_index:max_ale_index] - batch_value[step + 1, :].squeeze())

        value = value[:args.num_steps * minibatch_size, :]
        logit = logit[:args.num_steps * minibatch_size, :]

        log_probs = F.log_softmax(logit, dim=1)
        probs = F.softmax(logit, dim=1)

        action_log_probs = log_probs.gather(
            1, actions[:, min_ale_index:max_ale_index].contiguous().view(
                -1).unsqueeze(-1))
        dist_entropy = -(log_probs * probs).sum(-1).mean()

        advantages = returns[:-1, min_ale_index:max_ale_index].contiguous(
        ).view(-1).unsqueeze(-1) - value

        value_loss = advantages.pow(2).mean()
        policy_loss = -(action_log_probs * rhos.view(-1, 1).detach() * \
                (rewards[:, min_ale_index:max_ale_index].contiguous().view(-1, 1) + args.gamma * \
                returns[1:, min_ale_index:max_ale_index].contiguous().view(-1, 1) - value).detach()).mean()
        nvtx.range_pop()

        nvtx.range_push('train:backprop')
        loss = value_loss * args.value_loss_coef + policy_loss - dist_entropy * args.entropy_coef
        optimizer.zero_grad()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                       args.max_grad_norm)
        optimizer.step()
        nvtx.range_pop()

        # Shift the rolling window left by num_steps_per_update slots to
        # make room for the next acting phase.
        nvtx.range_push('train:next_states')
        for step in range(0, args.num_steps_per_update):
            states[:-1, :, :, :, :] = states[1:, :, :, :, :]
            rewards[:-1, :] = rewards[1:, :]
            actions[:-1, :] = actions[1:, :]
            masks[:-1, :] = masks[1:, :]
            mus[:-1, :] = mus[1:, :]
        nvtx.range_pop()

        torch.cuda.synchronize()

        if args.rank == 0:
            iter_time = time.time() - start_time
            total_time += iter_time

            if args.plot:
                writer.add_scalar('train/rewards_mean',
                                  final_rewards.mean().item(),
                                  T,
                                  walltime=total_time)
                writer.add_scalar('train/lengths_mean',
                                  final_lengths.mean().item(),
                                  T,
                                  walltime=total_time)
                writer.add_scalar('train/value_loss',
                                  value_loss,
                                  T,
                                  walltime=total_time)
                writer.add_scalar('train/policy_loss',
                                  policy_loss,
                                  T,
                                  walltime=total_time)
                writer.add_scalar('train/entropy',
                                  dist_entropy,
                                  T,
                                  walltime=total_time)

            progress_data = callback(args, model, T, iter_time, final_rewards,
                                     final_lengths, value_loss, policy_loss,
                                     dist_entropy, train_csv_writer,
                                     train_csv_file)
            iterator.set_postfix_str(progress_data)

    if args.plot:
        writer.close()

    if args.use_openai:
        train_env.close()
    if args.use_openai_test_env:
        test_env.close()
コード例 #28
0
def train():
    """Train a GAN according to the module-level FLAGS.

    Builds the dataset/loader, generator, discriminator, optimizers and LR
    schedulers, then runs the alternating D/G training loop with periodic
    image sampling, checkpointing and (optionally) IS/FID evaluation.
    """
    # Dataset selection; the Lambda transform adds small uniform noise
    # after normalization (presumably dequantization -- TODO confirm).
    if FLAGS.dataset == 'cifar10':
        dataset = datasets.CIFAR10(
            './data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                transforms.Lambda(lambda x: x + torch.rand_like(x) / 128)
            ]))
    if FLAGS.dataset == 'stl10':
        dataset = datasets.STL10(
            './data',
            split='unlabeled',
            download=True,
            transform=transforms.Compose([
                transforms.Resize((48, 48)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                transforms.Lambda(lambda x: x + torch.rand_like(x) / 128)
            ]))

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=True,
                                             num_workers=4,
                                             drop_last=True)

    # Networks and loss selected by architecture/loss name from FLAGS.
    net_G = net_G_models[FLAGS.arch](FLAGS.z_dim).to(device)
    net_D = net_D_models[FLAGS.arch]().to(device)
    loss_fn = loss_fns[FLAGS.loss]()

    optim_G = optim.Adam(net_G.parameters(), lr=FLAGS.lr_G, betas=FLAGS.betas)
    optim_D = optim.Adam(net_D.parameters(), lr=FLAGS.lr_D, betas=FLAGS.betas)
    # Linear LR decay from the full rate to zero over total_steps.
    sched_G = optim.lr_scheduler.LambdaLR(
        optim_G, lambda step: 1 - step / FLAGS.total_steps)
    sched_D = optim.lr_scheduler.LambdaLR(
        optim_D, lambda step: 1 - step / FLAGS.total_steps)

    # NOTE(review): raises FileExistsError if the directory already exists;
    # consider os.makedirs(..., exist_ok=True) to allow re-using a logdir.
    os.makedirs(os.path.join(FLAGS.logdir, 'sample'))
    writer = SummaryWriter(os.path.join(FLAGS.logdir))
    # Fixed latent batch so generated samples are comparable across steps.
    sample_z = torch.randn(FLAGS.sample_size, FLAGS.z_dim).to(device)
    with open(os.path.join(FLAGS.logdir, "flagfile.txt"), 'w') as f:
        f.write(FLAGS.flags_into_string())
    writer.add_text("flagfile",
                    FLAGS.flags_into_string().replace('\n', '  \n'))

    # Log one grid of real images for visual reference.
    real, _ = next(iter(dataloader))
    grid = (make_grid(real[:FLAGS.sample_size]) + 1) / 2
    writer.add_image('real_sample', grid)

    looper = infiniteloop(dataloader)
    with trange(1, FLAGS.total_steps + 1, dynamic_ncols=True) as pbar:
        for step in pbar:
            # Discriminator: n_dis updates per generator update.
            for _ in range(FLAGS.n_dis):
                with torch.no_grad():
                    z = torch.randn(FLAGS.batch_size, FLAGS.z_dim).to(device)
                    fake = net_G(z).detach()
                real = next(looper).to(device)
                net_D_real = net_D(real)
                net_D_fake = net_D(fake)
                loss = loss_fn(net_D_real, net_D_fake)

                optim_D.zero_grad()
                loss.backward()
                optim_D.step()

                # For the Wasserstein loss the optimized value is the
                # negation; flip the sign back for reporting only.
                if FLAGS.loss == 'was':
                    loss = -loss
                pbar.set_postfix(loss='%.4f' % loss)
            writer.add_scalar('loss', loss, step)

            # Generator update (uses a double-size latent batch).
            z = torch.randn(FLAGS.batch_size * 2, FLAGS.z_dim).to(device)
            loss = loss_fn(net_D(net_G(z)))

            optim_G.zero_grad()
            loss.backward()
            optim_G.step()

            sched_G.step()
            sched_D.step()
            pbar.update(1)

            # Periodic image samples from the fixed latent batch.
            if step == 1 or step % FLAGS.sample_step == 0:
                fake = net_G(sample_z).cpu()
                grid = (make_grid(fake) + 1) / 2
                writer.add_image('sample', grid, step)
                save_image(
                    grid, os.path.join(FLAGS.logdir, 'sample',
                                       '%d.png' % step))

            # Periodic checkpoint and optional IS/FID evaluation.
            if step == 1 or step % FLAGS.eval_step == 0:
                torch.save(
                    {
                        'net_G': net_G.state_dict(),
                        'net_D': net_D.state_dict(),
                        'optim_G': optim_G.state_dict(),
                        'optim_D': optim_D.state_dict(),
                        'sched_G': sched_G.state_dict(),
                        'sched_D': sched_D.state_dict(),
                    }, os.path.join(FLAGS.logdir, 'model.pt'))
                if FLAGS.record:
                    imgs = generate_imgs(net_G, device, FLAGS.z_dim, 50000,
                                         FLAGS.batch_size)
                    is_score, fid_score = get_inception_and_fid_score(
                        imgs, device, FLAGS.fid_cache, verbose=True)
                    pbar.write("%s/%s Inception Score: %.3f(%.5f), "
                               "FID Score: %6.3f" %
                               (step, FLAGS.total_steps, is_score[0],
                                is_score[1], fid_score))
                    writer.add_scalar('inception_score', is_score[0], step)
                    writer.add_scalar('inception_score_std', is_score[1], step)
                    writer.add_scalar('fid_score', fid_score, step)
    writer.close()
コード例 #29
0
def train(args):
    """Train a seismic-section segmentation model and log progress to TensorBoard.

    Splits the data into train/validation section lists, builds (or resumes)
    the model named by ``args.arch``, trains for ``args.n_epoch`` epochs with a
    2-D focal loss, periodically logs images/metrics, and checkpoints the best
    (by validation mean IoU) and every-10th-epoch model under ``runs/<ts>_...``.

    Args:
        args: argparse namespace with at least: arch, model_name, aug, per_val,
            batch_size, resume, pretrained, lr, weight_decay, eps, gamma,
            loss_parameters, clip, class_weights, n_epoch.

    Side effects: creates split text files, a TensorBoard run directory, and
    ``.pkl`` model checkpoints; prints progress to stdout.
    """
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")  # select torch device

    # Generate the train and validation section lists as text files.
    split_train_val(args, per_val=args.per_val)

    current_time = datetime.now().strftime('%b%d_%H%M%S')  # timestamp for the run
    log_dir = os.path.join('runs', current_time + f"_{args.arch}_{args.model_name}")
    writer = SummaryWriter(log_dir=log_dir)  # TensorBoard summary writer

    # Setup augmentations (rotation / horizontal flip / additive noise).
    if args.aug:
        data_aug = Compose(
            [RandomRotate(10),
             RandomHorizontallyFlip(),
             AddNoise()])
    else:
        data_aug = None

    loader = section_loader
    train_set = loader(is_transform=True, split='train', augmentations=data_aug)
    val_set = loader(is_transform=True, split='val')

    # Number of classes is fixed inside the data loader.
    n_classes = train_set.n_classes

    # Create sampler:
    shuffle = False  # must stay False if a custom sampler is actually used
    with open(pjoin('data', 'splits', 'section_train.txt'), 'r') as f:
        # Section names produced above by split_train_val().
        train_list = f.read().splitlines()
    with open(pjoin('data', 'splits', 'section_val.txt'), 'r') as f:
        val_list = f.read().splitlines()

    # NOTE(review): these two samplers are defined but never passed to the
    # DataLoaders below (shuffle=True is used instead) — confirm intent.
    class CustomSamplerTrain(torch.utils.data.Sampler):
        def __iter__(self):
            # Randomly restrict one epoch to inlines ('i') or crosslines ('x').
            char = ['i' if np.random.randint(2) == 1 else 'x']
            self.indices = [
                idx for (idx, name) in enumerate(train_list) if char[0] in name
            ]
            # Shuffle the chosen indices and yield them.
            return (self.indices[i] for i in torch.randperm(len(self.indices)))

    class CustomSamplerVal(torch.utils.data.Sampler):
        def __iter__(self):
            char = ['i' if np.random.randint(2) == 1 else 'x']
            self.indices = [
                idx for (idx, name) in enumerate(val_list) if char[0] in name
            ]
            return (self.indices[i] for i in torch.randperm(len(self.indices)))

    trainloader = data.DataLoader(
        train_set, batch_size=args.batch_size, num_workers=12, shuffle=True)
    valloader = data.DataLoader(
        val_set, batch_size=args.batch_size, num_workers=12)

    # Confusion-matrix based metrics for training and validation.
    running_metrics = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    # Setup model: resume a stored checkpoint if requested, else build fresh.
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            model = torch.load(args.resume)
        else:
            print("No checkpoint found at '{}'".format(args.resume))
    else:
        model = get_model(name=args.arch,
                          pretrained=args.pretrained,
                          batch_size=args.batch_size,
                          growth_rate=32,
                          drop_rate=0,
                          n_classes=n_classes)

    # Use every visible GPU, then move to the selected device.
    model = torch.nn.DataParallel(
        model, device_ids=range(torch.cuda.device_count()))
    model = model.to(device)

    # Prefer a model-supplied optimizer; fall back to Adam.
    if hasattr(model.module, 'optimizer'):
        print('Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
            amsgrad=True,
            weight_decay=args.weight_decay,
            eps=args.eps)

    loss_fn = core.loss.focal_loss2d  # 2-D focal loss

    if args.class_weights:
        # Weights inversely proportional to class frequency in the training set.
        # NOTE(review): computed but never passed to loss_fn below — confirm.
        class_weights = torch.tensor(
            [0.7151, 0.8811, 0.5156, 0.9346, 0.9683, 0.9852],
            device=device,
            requires_grad=False)
    else:
        class_weights = None

    best_iou = -100.0  # sentinel so the first validation always checkpoints
    class_names = [
        'null', 'upper_ns', 'middle_ns', 'lower_ns', 'rijnland_chalk',
        'scruff', 'zechstein'
    ]

    # Log every hyper-parameter as text before training starts.
    for arg in vars(args):
        text = arg + ': ' + str(getattr(args, arg))
        writer.add_text('Parameters/', text)

    # ------------------------------ training ------------------------------
    for epoch in range(args.n_epoch):
        model.train()
        loss_train, total_iteration = 0, 0

        for i, (images, labels) in enumerate(trainloader):
            # Keep CPU copies for visualization before moving to the device.
            image_original, labels_original = images, labels
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)  # (N, n_classes, H, W) logits

            # Update the training confusion matrix with hard predictions.
            pred = outputs.detach().max(1)[1].cpu().numpy()
            gt = labels.detach().cpu().numpy()
            running_metrics.update(gt, pred)

            loss = loss_fn(input=outputs,
                           target=labels,
                           gamma=args.gamma,
                           loss_type=args.loss_parameters)
            loss_train += loss.item()
            loss.backward()

            # Gradient clipping: norm computed over all gradients jointly,
            # modified in place. (Fixed: clip_grad_norm is deprecated/removed;
            # use the in-place clip_grad_norm_.)
            if args.clip != 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()
            total_iteration = total_iteration + 1

            if i % 20 == 0:  # progress print every 20 batches
                print("Epoch [%d/%d] training Loss: %.4f" %
                      (epoch + 1, args.n_epoch, loss.item()))

            numbers = [0, 14, 29, 49, 99]  # batches whose first image is logged
            if i in numbers:
                # First image of the batch, as a TensorBoard grid.
                tb_original_image = vutils.make_grid(
                    image_original[0][0], normalize=True, scale_each=True)
                writer.add_image('train/original_image', tb_original_image,
                                 epoch + 1)

                # Ground-truth labels of that image, decoded to a color map.
                labels_original = labels_original.numpy()[0]
                correct_label_decoded = train_set.decode_segmap(
                    np.squeeze(labels_original))
                writer.add_image('train/original_label',
                                 np_to_tb(correct_label_decoded), epoch + 1)

                out = F.softmax(outputs, dim=1)
                # argmax channel = predicted class, max prob = confidence.
                prediction = out.max(1)[1].cpu().numpy()[0]
                confidence = out.max(1)[0].cpu().detach()[0]

                tb_confidence = vutils.make_grid(
                    confidence, normalize=True, scale_each=True)

                decoded = train_set.decode_segmap(np.squeeze(prediction))
                writer.add_image('train/predicted', np_to_tb(decoded),
                                 epoch + 1)
                writer.add_image('train/confidence', tb_confidence, epoch + 1)

                # Min-max normalize the raw logits over the whole batch, then
                # log each class channel of the first image separately.
                unary = outputs.cpu().detach()
                unary_max = torch.max(unary)
                unary_min = torch.min(unary)
                unary = unary.add((-1 * unary_min))
                unary = unary / (unary_max - unary_min)

                for channel in range(0, len(class_names)):
                    decoded_channel = unary[0][channel]
                    tb_channel = vutils.make_grid(
                        decoded_channel, normalize=True, scale_each=True)
                    writer.add_image(f'train_classes/_{class_names[channel]}',
                                     tb_channel, epoch + 1)

        # Epoch-level training metrics (averaged over all batches).
        loss_train /= total_iteration
        score, class_iou = running_metrics.get_scores()
        writer.add_scalar('train/Pixel Acc', score['Pixel Acc: '], epoch + 1)
        writer.add_scalar('train/Mean Class Acc', score['Mean Class Acc: '],
                          epoch + 1)
        writer.add_scalar('train/Freq Weighted IoU',
                          score['Freq Weighted IoU: '], epoch + 1)
        writer.add_scalar('train/Mean_IoU', score['Mean IoU: '], epoch + 1)
        confusion = score['confusion_matrix']
        writer.add_image(f'train/confusion matrix', np_to_tb(confusion),
                         epoch + 1)

        running_metrics.reset()  # reset the confusion matrix
        writer.add_scalar('train/loss', loss_train, epoch + 1)

        # ----------------------------- validation -----------------------------
        if args.per_val != 0:
            with torch.no_grad():  # no history tracking during evaluation
                model.eval()
                # NOTE(review): loss_val is never accumulated; 'val/loss'
                # below logs only the last batch's loss — confirm intent.
                loss_val, total_iteration_val = 0, 0

                for i_val, (images_val, labels_val) in tqdm(
                        enumerate(valloader)):
                    image_original, labels_original = images_val, labels_val
                    images_val, labels_val = images_val.to(
                        device), labels_val.to(device)

                    outputs_val = model(images_val)
                    pred = outputs_val.detach().max(1)[1].cpu().numpy()
                    gt = labels_val.detach().cpu().numpy()

                    running_metrics_val.update(gt, pred)

                    loss = loss_fn(input=outputs_val,
                                   target=labels_val,
                                   gamma=args.gamma,
                                   loss_type=args.loss_parameters)
                    total_iteration_val = total_iteration_val + 1

                    if i_val % 20 == 0:
                        # fixed: use epoch + 1, consistent with training print
                        print("Epoch [%d/%d] validation Loss: %.4f" %
                              (epoch + 1, args.n_epoch, loss.item()))

                    numbers = [0]
                    if i_val in numbers:  # visualize batch 0 only
                        tb_original_image = vutils.make_grid(
                            image_original[0][0],
                            normalize=True,
                            scale_each=True)
                        # fixed: log at epoch + 1 like every other val image
                        writer.add_image('val/original_image',
                                         tb_original_image, epoch + 1)
                        labels_original = labels_original.numpy()[0]
                        correct_label_decoded = train_set.decode_segmap(
                            np.squeeze(labels_original))
                        writer.add_image('val/original_label',
                                         np_to_tb(correct_label_decoded),
                                         epoch + 1)

                        out = F.softmax(outputs_val, dim=1)

                        # Max channel index = predicted class:
                        prediction = out.max(1)[1].cpu().detach().numpy()[0]
                        # Max channel value = confidence in that class:
                        confidence = out.max(1)[0].cpu().detach()[0]
                        tb_confidence = vutils.make_grid(
                            confidence, normalize=True, scale_each=True)

                        decoded = train_set.decode_segmap(
                            np.squeeze(prediction))
                        writer.add_image('val/predicted', np_to_tb(decoded),
                                         epoch + 1)
                        writer.add_image('val/confidence', tb_confidence,
                                         epoch + 1)

                        # fixed: visualize the VALIDATION outputs; the original
                        # reused `outputs`, a stale tensor from the train loop.
                        unary = outputs_val.cpu().detach()
                        unary_max, unary_min = torch.max(unary), torch.min(
                            unary)
                        unary = unary.add((-1 * unary_min))
                        unary = unary / (unary_max - unary_min)

                        for channel in range(0, len(class_names)):
                            tb_channel = vutils.make_grid(
                                unary[0][channel],
                                normalize=True,
                                scale_each=True)
                            writer.add_image(
                                f'val_classes/_{class_names[channel]}',
                                tb_channel, epoch + 1)

                # Epoch-level validation metrics.
                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)

                writer.add_scalar('val/Pixel Acc', score['Pixel Acc: '],
                                  epoch + 1)
                writer.add_scalar('val/Mean IoU', score['Mean IoU: '],
                                  epoch + 1)
                writer.add_scalar('val/Mean Class Acc',
                                  score['Mean Class Acc: '], epoch + 1)
                writer.add_scalar('val/Freq Weighted IoU',
                                  score['Freq Weighted IoU: '], epoch + 1)
                confusion = score['confusion_matrix']
                writer.add_image(f'val/confusion matrix', np_to_tb(confusion),
                                 epoch + 1)
                writer.add_scalar('val/loss', loss.item(), epoch + 1)
                running_metrics_val.reset()

                # Checkpoint whenever validation mean IoU matches or beats
                # the best seen so far.
                if score['Mean IoU: '] >= best_iou:
                    best_iou = score['Mean IoU: ']
                    model_dir = os.path.join(log_dir,
                                             f"{args.arch}_model_best.pkl")
                    torch.save(model, model_dir)

                # Also checkpoint every 10 epochs regardless of score.
                if epoch % 10 == 0:
                    model_dir = os.path.join(
                        log_dir, f"{args.arch}_ep{epoch}_model.pkl")
                    torch.save(model, model_dir)

        else:  # validation turned off: just save the latest model every 10 epochs
            if (epoch + 1) % 10 == 0:
                model_dir = os.path.join(
                    log_dir, f"{args.arch}_ep{epoch + 1}_model.pkl")
                torch.save(model, model_dir)

    writer.close()
Code example #30
class Train:
    """Train a CNNDriver steering model and log progress to TensorBoard.

    ``__init__`` builds the model, traces its graph, and moves it to the
    requested device; ``train`` runs the optimization loop over an LMDB
    driving dataset.
    """

    def __init__(self, gpu='0'):
        # Device configuration: use the requested CUDA device if available.
        self.__device = torch.device('cuda:'+gpu if torch.cuda.is_available() else 'cpu')
        self.__writer = SummaryWriter('logs')
        self.__model = CNNDriver()
        # Set model to train mode.
        self.__model.train()
        print(self.__model)
        # Trace the graph with a dummy (10, 3, 66, 200) batch before the
        # model is moved to the GPU.
        self.__writer.add_graph(self.__model, torch.rand(10, 3, 66, 200))
        # Put model on GPU.
        self.__model = self.__model.to(self.__device)
        # Placeholders populated by train(). (Fixed: these were mutable
        # class-level list attributes shared by the whole class.)
        self.__transformations = None
        self.__dataset_train = None
        self.__train_loader = None
        self.__loss_func = None
        self.__optimizer = None
        self.__exp_lr_scheduler = None

    def train(self, num_epochs=100, batch_size=400, lr=0.0001, l2_norm=0.001, save_dir='./save', input='./DataLMDB'):
        """Run the training loop.

        Args:
            num_epochs: number of passes over the dataset.
            batch_size: samples per batch.
            lr: Adam learning rate.
            l2_norm: weight decay (L2 regularization) factor.
            save_dir: directory for per-epoch ``.pkl`` checkpoints.
            input: LMDB dataset path. NOTE: shadows the ``input`` builtin;
                kept for backward compatibility with existing callers.
        """
        # Create log/save directories if they do not exist.
        if not os.path.exists('./logs'):
            os.makedirs('./logs')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        self.__transformations = transforms.Compose([AugmentDrivingTransform(), 
                                                     RandomBrightness(), ConvertToGray(), 
                                                     ConvertToSepia(), AddNoise(), DrivingDataToTensor(),])
        self.__dataset_train = DriveData_LMDB(input, self.__transformations)
        self.__train_loader = DataLoader(self.__dataset_train, batch_size=batch_size, shuffle=True, num_workers=4)

        # Loss and optimizer (MSE on the steering regression output).
        self.__loss_func = nn.MSELoss()
        # self.__loss_func = nn.SmoothL1Loss()
        self.__optimizer = torch.optim.Adam(self.__model.parameters(), lr=lr, weight_decay=l2_norm)

        # Decay LR by a factor of 0.1 every 15 epochs.
        self.__exp_lr_scheduler = lr_scheduler.StepLR(self.__optimizer, step_size=15, gamma=0.1)

        print('Train size:', len(self.__dataset_train), 'Batch size:', batch_size)
        print('Batches per epoch:', len(self.__dataset_train) // batch_size)

        # Train the model.
        iteration_count = 0
        for epoch in range(num_epochs):
            for batch_idx, samples in enumerate(self.__train_loader):

                # Send inputs/labels to the device.
                images = samples['image'].to(self.__device)
                labels = samples['label'].to(self.__device)

                self.__optimizer.zero_grad()

                # Forward + backward + optimize.
                outputs = self.__model(images)
                loss = self.__loss_func(outputs, labels.unsqueeze(dim=1))

                loss.backward()
                self.__optimizer.step()
                # NOTE(review): step(epoch) per batch uses the deprecated
                # explicit-epoch API; modern code calls scheduler.step()
                # once per epoch. Kept to preserve the original schedule.
                self.__exp_lr_scheduler.step(epoch)

                # Send loss and steering histograms to TensorBoard.
                self.__writer.add_scalar('loss/', loss.item(), iteration_count)
                self.__writer.add_histogram('steering_out', outputs.clone().detach().cpu().numpy(), iteration_count, bins='doane')
                self.__writer.add_histogram('steering_in', 
                                            labels.unsqueeze(dim=1).clone().detach().cpu().numpy(), iteration_count, bins='doane')

                # Record the current learning rate.
                for param_group in self.__optimizer.param_groups:
                    curr_learning_rate = param_group['lr']
                    self.__writer.add_scalar('learning_rate/', curr_learning_rate, iteration_count)

                # Once per epoch (first batch): log, print, and checkpoint.
                if batch_idx == 0:
                    # NOTE(review): add_image expects a single CHW image or a
                    # make_grid output, but receives a 4-D batch — confirm.
                    self.__writer.add_image('Image', images, epoch)
                    self.__writer.add_text('Steering', 'Steering:' + str(outputs[batch_idx].item()), epoch)
                    print('Epoch [%d/%d] Loss: %.4f' % (epoch + 1, num_epochs, loss.item()))
                    # Save the trained model parameters.
                    torch.save(self.__model.state_dict(), save_dir+'/cnn_' + str(epoch) + '.pkl')

                iteration_count += 1
Code example #31
def training(main_args, model_args, model=None):
    if "task_type" in model_args and model_args.task_type is not None:
        main_args.task_type = model_args.task_type
    dir_ret = get_model_info(main_args=main_args, model_args=model_args)
    model, optimizer, vocab = build_model(main_args=main_args,
                                          model_args=model_args,
                                          model=model)
    train_set, dev_set = load_data(main_args)
    model_file = dir_ret['model_file']
    log_dir = dir_ret['log_dir']
    out_dir = dir_ret['out_dir']

    writer = SummaryWriter(log_dir)
    GlobalOps.writer_ops = writer
    writer.add_text("main_args", str(main_args))
    writer.add_text("model_args", str(model_args))

    print("...... Start Training ......")

    train_iter = main_args.start_iter
    train_nums, train_loss = 0., 0.
    epoch, num_trial, patience, = 0, 0, 0

    history_scores = []
    task_type = main_args.task_type
    eval_select = eval_key_dict[task_type.lower()]
    sort_key = sort_key_dict[
        model_args.sort_key] if "sort_key" in model_args else sort_key_dict[
            model_args.enc_type]
    evaluator = get_evaluator(eval_choice=eval_select,
                              model=model,
                              eval_set=dev_set.examples,
                              eval_lists=main_args.eval_lists,
                              sort_key=sort_key,
                              eval_tgt="tgt",
                              batch_size=model_args.eval_bs,
                              out_dir=out_dir,
                              write_down=True)
    print("Dev ITEM: ", evaluator.score_item)
    adv_training = "adv_train" in model_args and model_args.adv_train
    adv_syn, adv_sem = False, False

    def hyper_init():
        adv_syn_ = (model_args.adv_syn +
                    model_args.infer_weight * model_args.inf_sem) > 0.
        adv_sem_ = (model_args.adv_sem +
                    model_args.infer_weight * model_args.inf_syn) > 0.
        return adv_syn_, adv_sem_

    if adv_training:
        adv_syn, adv_sem = hyper_init()

    def normal_training():
        optimizer.zero_grad()
        batch_ret_ = model.get_loss(examples=batch_examples,
                                    return_enc_state=False,
                                    train_iter=train_iter)
        batch_loss_ = batch_ret_['Loss']
        return batch_ret_

    def universal_training():
        if adv_training:
            ret_loss = model.get_loss(examples=batch_examples,
                                      train_iter=train_iter,
                                      is_dis=True)
            if adv_syn:
                dis_syn_loss = ret_loss['dis syn']
                optimizer.zero_grad()
                dis_syn_loss.backward()
                if main_args.clip_grad > 0.:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   main_args.clip_grad)
                optimizer.step()
            if adv_sem:
                ret_loss = model.get_loss(examples=batch_examples,
                                          train_iter=train_iter,
                                          is_dis=True)
                dis_sem_loss = ret_loss['dis sem']
                optimizer.zero_grad()
                dis_sem_loss.backward()
                if main_args.clip_grad > 0.:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   main_args.clip_grad)
                optimizer.step()
        return normal_training()

    while True:
        epoch += 1
        epoch_begin = time.time()
        train_log_dict = {}

        for batch_examples in train_set.batch_iter(
                batch_size=main_args.batch_size, shuffle=True):
            train_iter += 1
            train_nums += len(batch_examples)
            # batch_ret = model.get_loss(examples=batch_examples, return_enc_state=False, train_iter=train_iter)
            batch_ret = universal_training()
            batch_loss = batch_ret['Loss']
            train_loss += batch_loss.sum().item()
            torch.mean(batch_loss).backward()

            if main_args.clip_grad > 0.:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               main_args.clip_grad)
            optimizer.step()

            train_log_dict = update_tracker(batch_ret, train_log_dict)

            if train_iter % main_args.log_every == 0:
                print('\r[Iter %d] Train loss=%.5f' %
                      (train_iter, train_loss / train_nums),
                      file=sys.stdout,
                      end=" ")
                for key, val in train_log_dict.items():
                    if isinstance(val, torch.Tensor):
                        writer.add_scalar(tag="{}/Train/{}".format(
                            task_type, key),
                                          scalar_value=torch.mean(val).item(),
                                          global_step=train_iter)
                writer.add_scalar(tag="Optimize/lr",
                                  scalar_value=optimizer.param_groups[0]['lr'],
                                  global_step=train_iter)
                writer.add_scalar(
                    tag='Optimize/trial',
                    scalar_value=num_trial,
                    global_step=train_iter,
                )
                writer.add_scalar(
                    tag='Optimize/patience',
                    scalar_value=patience,
                    global_step=train_iter,
                )

            if train_iter % main_args.dev_every == 0 and train_iter > model_args.warm_up:
                eval_result_dict = evaluator()
                dev_acc = eval_result_dict[evaluator.score_item]
                if isinstance(dev_acc, torch.Tensor):
                    dev_acc = dev_acc.sum().item()
                print('\r[Iter %d] %s %s=%.3f took %d s' %
                      (train_iter, task_type, evaluator.score_item, dev_acc,
                       eval_result_dict['EVAL TIME']),
                      file=sys.stdout)
                is_better = (history_scores
                             == []) or dev_acc > max(history_scores)
                history_scores.append(dev_acc)

                writer.add_scalar(tag='%s/Valid/Best %s' %
                                  (task_type, evaluator.score_item),
                                  scalar_value=max(history_scores),
                                  global_step=train_iter)
                for key, val in eval_result_dict.items():
                    writer.add_scalar(tag="{}/Valid/{}".format(task_type, key),
                                      scalar_value=val.sum().item() if
                                      isinstance(val, torch.Tensor) else val,
                                      global_step=train_iter)

                model, optimizer, num_trial, patience = get_lr_schedule(
                    is_better=is_better,
                    model_file=model_file,
                    main_args=main_args,
                    patience=patience,
                    num_trial=num_trial,
                    model=model,
                    optimizer=optimizer,
                    reload_model=False)
                model.train()

        epoch_time = time.time() - epoch_begin
        print('\r[Epoch %d] epoch elapsed %ds' % (epoch, epoch_time),
              file=sys.stdout)