Example #1
File: train.py Project: ihaeyong/r2c
test_loader = VCRLoader.from_dataset(test, **loader_params)

ARGS_RESET_EVERY = 100
print("Loading {} for {}".format(params['model'].get('type', 'WTF?'),
                                 'rationales' if args.rationale else 'answer'),
      flush=True)
model = Model.from_params(vocab=train.vocab, params=params['model'])
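# Freeze the detector backbone: disable BatchNorm running-stats updates and turn off gradients for its parameters.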
for submodule in model.detector.backbone.modules():
    if isinstance(submodule, BatchNorm2d):
        submodule.track_running_stats = False
    for p in submodule.parameters():
        p.requires_grad = False

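# Wrap the model in DataParallel only when more than one GPU is available.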
model = DataParallel(model).cuda() if NUM_GPUS > 1 else model.cuda()
optimizer = Optimizer.from_params(
    [x for x in model.named_parameters() if x[1].requires_grad],
    params['trainer']['optimizer'])

lr_scheduler_params = params['trainer'].pop("learning_rate_scheduler", None)
scheduler = LearningRateScheduler.from_params(
    optimizer, lr_scheduler_params) if lr_scheduler_params else None

if os.path.exists(args.folder):
    print("Found folder! restoring", flush=True)
    start_epoch, val_metric_per_epoch = restore_checkpoint(
        model,
        optimizer,
        serialization_dir=args.folder,
        learning_rate_scheduler=scheduler)
else:
    print("Making directories")
Example #2
def main(model_args):
    args = get_config(PATH=model_args.model_path,
                      config_json_name=model_args.model_config_name)
    args.check_point = model_args.model_name
    args.data_path = model_args.data_path
    args.test_batch_size = model_args.test_batch_size
    args.doc_threshold = model_args.doc_threshold
    args.save_path = model_args.model_path
    if torch.cuda.is_available():
        args.cuda = True
    else:
        args.cuda = False
    ###################
    if args.data_path is None:
        raise ValueError('data_path must be chosen.')
    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    set_logger(args)
    ########+++++++++++++++++++++++++++++
    abs_path = os.path.abspath(args.data_path)
    args.data_path = abs_path
    ########+++++++++++++++++++++++++++++
    # Write logs to checkpoint and console
    if args.cuda:
        if args.gpu_num > 1:
            device_ids, used_memory = gpu_setting(args.gpu_num)
        else:
            device_ids, used_memory = gpu_setting()
        if used_memory > 100:
            logging.info('Using memory = {}'.format(used_memory))
        if device_ids is not None:
            if len(device_ids) > args.gpu_num:
                device_ids = device_ids[:args.gpu_num]
            device = torch.device('cuda:{}'.format(device_ids[0]))
        else:
            device = torch.device('cuda:0')
        logging.info('Set the cuda with idxes = {}'.format(device_ids))
        logging.info('cuda setting {}'.format(device))
        logging.info('GPU setting')
    else:
        device_ids = None
        device = torch.device('cpu')
        logging.info('CPU setting')
    ########+++++++++++++++++++++++++++++
    logging.info('Loading development data...')
    test_data_loader, _ = get_test_data_loader(args=args)
    logging.info('Loading data completed')
    logging.info('*' * 75)
    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    logging.info('Loading Model...')
    model = get_model(args=args).to(device)
    ##+++++++++++
    model_path = args.save_path
    model_file_name = args.check_point
    hotpot_qa_model_name = os.path.join(model_path, model_file_name)
    model = load_model(model=model, PATH=hotpot_qa_model_name)
    model = model.to(device)
    if device_ids is not None:
        if len(device_ids) > 1:
            model = DataParallel(model,
                                 device_ids=device_ids,
                                 output_device=device)
            logging.info('Data Parallel model setting')
    ##+++++++++++
    logging.info('Model Parameter Configuration:')
    for name, param in model.named_parameters():
        logging.info('Parameter {}: {}, require_grad = {}'.format(
            name, str(param.size()), str(param.requires_grad)))
    logging.info('*' * 75)
    logging.info("Model hype-parameter information...")
    for key, value in vars(args).items():
        logging.info('Hype-parameter\t{} = {}'.format(key, value))
    logging.info('*' * 75)
    logging.info("Model hype-parameter information...")
    for key, value in vars(model_args).items():
        logging.info('Hype-parameter\t{} = {}'.format(key, value))
    logging.info('*' * 75)
    logging.info('projection_dim = {}'.format(args.project_dim))
    logging.info('Multi-task encoding')
    logging.info('*' * 75)
    logging.info('Loading tokenizer')
    tokenizer = get_hotpotqa_longformer_tokenizer()
    logging.info('*' * 75)
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    # logging.info('Multi-task encoding')
    # metric_dict = multi_task_decoder(model=model, device=device, test_data_loader=test_data_loader, args=args)
    # answer_type_acc = metric_dict['answer_type_acc']
    # logging.info('*' * 75)
    # logging.info('Answer type prediction accuracy: {}'.format(answer_type_acc))
    # logging.info('*' * 75)
    # for key, value in metric_dict.items():
    #     if key.endswith('metrics'):
    #         logging.info('{} prediction'.format(key))
    #         log_metrics('Valid', 'final', value)
    # logging.info('*' * 75)
    # ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    # dev_data_frame = metric_dict['res_dataframe']
    # ##################################################
    # leadboard_metric, res_data_frame = convert2leadBoard(data=dev_data_frame, tokenizer=tokenizer)
    # ##=================================================
    # logging.info('*' * 75)
    # log_metrics('Evaluation', step='leadboard', metrics=leadboard_metric)
    # logging.info('*' * 75)
    # date_time_str = get_date_time()
    # dev_result_name = os.path.join(args.save_path,
    #                                date_time_str + '_mt_evaluation.json')
    # res_data_frame.to_json(dev_result_name, orient='records')
    # logging.info('Saving {} record results to {}'.format(res_data_frame.shape, dev_result_name))
    # logging.info('*' * 75)
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    logging.info('Hierarchical encoding')
    metric_dict = hierartical_decoder(model=model,
                                      device=device,
                                      test_data_loader=test_data_loader,
                                      doc_topk=model_args.doc_topk,
                                      args=args)
    answer_type_acc = metric_dict['answer_type_acc']
    logging.info('*' * 75)
    logging.info('Answer type prediction accuracy: {}'.format(answer_type_acc))
    logging.info('*' * 75)
    for key, value in metric_dict.items():
        if key.endswith('metrics'):
            logging.info('{} prediction'.format(key))
            log_metrics('Valid', 'final', value)
        logging.info('*' * 75)
    ##++++++++++++++++++++++++++++++++++++++++++++++++++++
    topk_dev_data_frame = metric_dict['topk_dataframe']
    ##################################################
    topk_leadboard_metric, topk_res_data_frame = convert2leadBoard(
        data=topk_dev_data_frame, tokenizer=tokenizer)
    ##=================================================
    log_metrics('Topk Evaluation',
                step='leadboard',
                metrics=topk_leadboard_metric)
    date_time_str = get_date_time()
    topk_dev_result_name = os.path.join(
        args.save_path, date_time_str + '_topk_hi_evaluation.json')
    topk_res_data_frame.to_json(topk_dev_result_name, orient='records')
    logging.info('Saving {} record results to {}'.format(
        topk_res_data_frame.shape, topk_dev_result_name))
    logging.info('*' * 75)
    ##=================================================
    thresh_dev_data_frame = metric_dict['thresh_dataframe']
    ##################################################
    thresh_leadboard_metric, thresh_res_data_frame = convert2leadBoard(
        data=thresh_dev_data_frame, tokenizer=tokenizer)
    log_metrics('Thresh Evaluation',
                step='leadboard',
                metrics=thresh_leadboard_metric)
    ##=================================================
    date_time_str = get_date_time()
    thresh_dev_result_name = os.path.join(
        args.save_path, date_time_str + '_thresh_hi_evaluation.json')
    thresh_res_data_frame.to_json(thresh_dev_result_name, orient='records')
    logging.info('Saving {} record results to {}'.format(
        thresh_res_data_frame.shape, thresh_dev_result_name))
    logging.info('*' * 75)
Example #3
class Solver():
    def __init__(self, config, channel_list):
        # Config - Model
        self.z_dim = config.z_dim
        self.channel_list = channel_list

        # Config - Training
        self.batch_size = config.batch_size
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.decay_ratio = config.decay_ratio
        self.decay_iter = config.decay_iter
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.n_critic = config.n_critic
        self.lambda_gp = config.lambda_gp
        self.max_iter = config.max_iter

        self.r1_iter = config.r1_iter
        self.r1_lambda = config.r1_lambda
        self.ppl_iter = config.ppl_iter
        self.ppl_lambda = config.ppl_lambda

        # Config - Test
        self.fixed_z = torch.randn(512, config.z_dim).to(dev)

        # Config - Path
        self.data_root = config.data_root
        self.log_root = config.log_root
        self.model_root = config.model_root
        self.sample_root = config.sample_root

        # Config - Miscellaneous
        self.print_loss_iter = config.print_loss_iter
        self.save_image_iter = config.save_image_iter
        self.save_parameter_iter = config.save_parameter_iter
        self.save_log_iter = config.save_log_iter

        self.writer = SummaryWriter(self.log_root)

    def build_model(self):
        self.G = Generator(channel_list=self.channel_list)
        self.G_ema = Generator(channel_list=self.channel_list)
        self.D = Discriminator(channel_list=self.channel_list)
        self.M = MappingNetwork(z_dim=self.z_dim)

        self.G = DataParallel(self.G).to(dev)
        self.G_ema = DataParallel(self.G_ema).to(dev)
        self.D = DataParallel(self.D).to(dev)
        self.M = DataParallel(self.M).to(dev)

        G_M_params = list(self.G.parameters()) + list(self.M.parameters())

        self.g_optimizer = torch.optim.Adam(params=G_M_params,
                                            lr=self.g_lr,
                                            betas=[self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(params=self.D.parameters(),
                                            lr=self.d_lr,
                                            betas=[self.beta1, self.beta2])

        self.g_scheduler = lr_scheduler.StepLR(self.g_optimizer,
                                               step_size=self.decay_iter,
                                               gamma=self.decay_ratio)
        self.d_scheduler = lr_scheduler.StepLR(self.d_optimizer,
                                               step_size=self.decay_iter,
                                               gamma=self.decay_ratio)

        print("Print model G, D")
        print(self.G)
        print(self.D)

    def load_model(self, pkl_path, channel_list):
        ckpt = torch.load(pkl_path)

        self.G = Generator(channel_list=channel_list)
        self.G_ema = Generator(channel_list=channel_list)
        self.D = Discriminator(channel_list=channel_list)
        self.M = MappingNetwork(z_dim=self.z_dim)

        self.G = DataParallel(self.G).to(dev)
        self.G_ema = DataParallel(self.G_ema).to(dev)
        self.D = DataParallel(self.D).to(dev)
        self.M = DataParallel(self.M).to(dev)

        self.G.load_state_dict(ckpt["G"])
        self.G_ema.load_state_dict(ckpt["G_ema"])
        self.D.load_state_dict(ckpt["D"])
        self.M.load_state_dict(ckpt["M"])

    def save_model(self, iters):
        file_name = 'ckpt_%d.pkl' % iters
        ckpt_path = os.path.join(self.model_root, file_name)
        ckpt = {
            'M': self.M.state_dict(),
            'G': self.G.state_dict(),
            'G_ema': self.G_ema.state_dict(),
            'D': self.D.state_dict()
        }
        torch.save(ckpt, ckpt_path)

    def save_img(self, iters, fixed_w):
        img_path = os.path.join(self.sample_root, "%d.png" % iters)
        with torch.no_grad():
            fixed_w = fixed_w[:self.batch_size * 2]
            dlatents_in = make_latents(fixed_w, self.batch_size,
                                       len(self.channel_list))
            generated_imgs, _ = self.G_ema(dlatents_in)
            save_image(
                make_grid(generated_imgs.cpu() / 2 + 1 / 2, nrow=4, padding=2),
                img_path)

    def reset_grad(self):
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def lr_update(self):
        self.g_scheduler.step()
        self.d_scheduler.step()

    def set_phase(self, mode="train"):
        if mode == "train":
            self.G.train()
            self.G_ema.train()
            self.D.train()
            self.M.train()

        elif mode == "test":
            self.G.eval()
            self.G_ema.eval()
            self.D.eval()
            self.M.eval()

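    # Keep an exponential moving average of the generator weights:
    # G_ema <- beta * G_ema + (1 - beta) * G.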
    def exponential_moving_average(self, beta=0.999):
        with torch.no_grad():
            G_param_dict = dict(self.G.named_parameters())
            for name, g_ema_param in self.G_ema.named_parameters():
                g_param = G_param_dict[name]
                g_ema_param.copy_(beta * g_ema_param + (1. - beta) * g_param)

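    # R1 regularization: penalize the squared gradient norm of the discriminator
    # output with respect to real images.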
    def r1_regularization(self, real_pred, real_img):
        grad_real = torch.autograd.grad(outputs=real_pred.sum(),
                                        inputs=real_img,
                                        create_graph=True)[0]
        grad_penalty = grad_real.pow(2).view(grad_real.size(0),
                                             -1).sum(1).mean()
        return grad_penalty

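    # Path length regularization (StyleGAN2): track a running mean of the latent-to-image
    # Jacobian norm and penalize per-sample deviations from it.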
    def path_length_regularization(self,
                                   fake_img,
                                   latents,
                                   mean_path_length,
                                   decay=0.01):
        noise = torch.randn_like(fake_img) / math.sqrt(
            fake_img.shape[2] * fake_img.shape[3])
        grad = torch.autograd.grad(outputs=(fake_img * noise).sum(),
                                   inputs=latents,
                                   create_graph=True)[0]
        path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
        path_mean = mean_path_length + decay * (path_lengths.mean() -
                                                mean_path_length)
        path_penalty = (path_lengths - path_mean).pow(2).mean()
        return path_penalty, path_mean.detach(), path_lengths

    def train(self):
        # build model
        self.build_model()
        loader = data_loader(self.data_root, self.batch_size, img_size=512)
        loader = iter(cycle(loader))
        mean_path_length = torch.tensor(0.0).to(dev)
        average_path_length = torch.tensor(0.0).to(dev)

        for iters in tqdm(range(self.max_iter + 1)):
            real_img = next(loader)
            real_img = real_img.to(dev)
            # ===============================================================#
            #                    1. Train the discriminator                  #
            # ===============================================================#
            self.set_phase(mode="train")
            self.reset_grad()

            # Compute loss with real images.
            d_real_out = self.D(real_img)
            d_loss_real = F.softplus(-d_real_out).mean()

            # Compute loss with fake images.
            z = torch.randn(2 * self.batch_size, self.z_dim).to(dev)
            w = self.M(z)
            dlatents_in = make_latents(w, self.batch_size,
                                       len(self.channel_list))
            fake_img, _ = self.G(dlatents_in)
            d_fake_out = self.D(fake_img.detach())
            d_loss_fake = F.softplus(d_fake_out).mean()

            d_loss = d_loss_real + d_loss_fake

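            # Lazy R1 regularization: computed only every r1_iter steps, so the penalty is scaled by r1_iter.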
            if iters % self.r1_iter == 0:
                real_img.requires_grad = True
                d_real_out = self.D(real_img)
                r1_loss = self.r1_regularization(d_real_out, real_img)
                r1_loss = self.r1_lambda / 2 * r1_loss * self.r1_iter
                d_loss = d_loss + r1_loss

            d_loss.backward()
            self.d_optimizer.step()
            # ===============================================================#
            #                      2. Train the Generator                    #
            # ===============================================================#

            if (iters + 1) % self.n_critic == 0:
                self.reset_grad()

                # Compute loss with fake images.
                z = torch.randn(2 * self.batch_size, self.z_dim).to(dev)
                w = self.M(z)
                dlatents_in = make_latents(w, self.batch_size,
                                           len(self.channel_list))
                fake_img, _ = self.G(dlatents_in)
                d_fake_out = self.D(fake_img)
                g_loss = F.softplus(-d_fake_out).mean()

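                # Lazy path length regularization: computed only every ppl_iter steps, scaled by ppl_iter.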
                if iters % self.ppl_iter == 0:
                    path_loss, mean_path_length, path_length = self.path_length_regularization(
                        fake_img, dlatents_in, mean_path_length)
                    path_loss = path_loss * self.ppl_iter * self.ppl_lambda
                    g_loss = g_loss + path_loss
                    mean_path_length = mean_path_length.mean()
                    average_path_length += mean_path_length.mean()

                # Backward and optimize.
                g_loss.backward()
                self.g_optimizer.step()

            # ===============================================================#
            #                   3. Save parameters and images                #
            # ===============================================================#
            # self.lr_update()
            torch.cuda.synchronize()
            self.set_phase(mode="test")
            self.exponential_moving_average()

            # Print total loss
            if iters % self.print_loss_iter == 0:
                print(
                    "Iter : [%d/%d], D_loss : [%.3f, %.3f, %.3f.], G_loss : %.3f, R1_reg : %.3f, "
                    "PPL_reg : %.3f, Path_length : %.3f" %
                    (iters, self.max_iter, d_loss.item(), d_loss_real.item(),
                     d_loss_fake.item(), g_loss.item(), r1_loss.item(),
                     path_loss.item(), mean_path_length.item()))

            # Save generated images.
            if iters % self.save_image_iter == 0:
                fixed_w = self.M(self.fixed_z)
                self.save_img(iters, fixed_w)

            # Save the G and D parameters.
            if iters % self.save_parameter_iter == 0:
                self.save_model(iters)

            # Save the logs on the tensorboard.
            if iters % self.save_log_iter == 0:
                self.writer.add_scalar('g_loss/g_loss', g_loss.item(), iters)
                self.writer.add_scalar('d_loss/d_loss_total', d_loss.item(),
                                       iters)
                self.writer.add_scalar('d_loss/d_loss_real',
                                       d_loss_real.item(), iters)
                self.writer.add_scalar('d_loss/d_loss_fake',
                                       d_loss_fake.item(), iters)
                self.writer.add_scalar('reg/r1_regularization', r1_loss.item(),
                                       iters)
                self.writer.add_scalar('reg/ppl_regularization',
                                       path_loss.item(), iters)

                self.writer.add_scalar('length/path_length',
                                       mean_path_length.item(), iters)
                self.writer.add_scalar(
                    'length/avg_path_length',
                    average_path_length.item() / (iters // self.ppl_iter + 1),
                    iters)
Example #4
                              transform=transform)
training_dataloader = data.DataLoader(training_dataset,
                                      batch_size=args.bs,
                                      shuffle=True,
                                      num_workers=8,
                                      pin_memory=True)

model = train_net()
model.apply(weights_init)
load_vgg16pretrain(model)
model = DataParallel(model).cuda()

weight = []
bias = []

for name, p in model.named_parameters():
    if 'weight' in name:
        weight.append(p)
    else:
        bias.append(p)

optimizer = optim.SGD([{
    "params": weight,
    "lr": args.lr,
    "weight_decay": 0
}, {
    "params": bias,
    "lr": 2 * args.lr,
    "weight_decay": args.wd
}],
                      momentum=args.momentum)
Example #5
class ProGAN:
    """ Wrapper around the Generator and the Discriminator """
    def __init__(self,
                 depth=7,
                 latent_size=512,
                 learning_rate=0.001,
                 beta_1=0,
                 beta_2=0.99,
                 eps=1e-8,
                 drift=0.001,
                 n_critic=1,
                 use_eql=True,
                 loss="wgan-gp",
                 use_ema=True,
                 ema_decay=0.999,
                 device=th.device("cpu")):
        """
        constructor for the class
        :param depth: depth of the GAN (will be used for each generator and discriminator)
        :param latent_size: latent size of the manifold used by the GAN
        :param learning_rate: learning rate for Adam
        :param beta_1: beta_1 for Adam
        :param beta_2: beta_2 for Adam
        :param eps: epsilon for Adam
        :param n_critic: number of times to update discriminator
                         (Used only if loss is wgan or wgan-gp)
        :param drift: drift penalty for the discriminator output
                      (Used only if loss is wgan or wgan-gp)
        :param use_eql: whether to use equalized learning rate
        :param loss: the loss function to be used
                     Can either be a string =>
                          ["wgan-gp", "wgan", "lsgan", "lsgan-with-sigmoid"]
                     Or an instance of GANLoss
        :param use_ema: boolean for whether to use exponential moving averages
        :param ema_decay: value of mu for ema
        :param device: device to run the GAN on (GPU / CPU)
        """

        from torch.optim import Adam
        from torch.nn import DataParallel

        # Create the Generator and the Discriminator
        self.gen = Generator(depth, latent_size, use_eql=use_eql).to(device)
        self.dis = Discriminator(depth, latent_size,
                                 use_eql=use_eql).to(device)

        # if code is to be run on GPU, we can use DataParallel:
        if device == th.device("cuda"):
            self.gen = DataParallel(self.gen)
            self.dis = DataParallel(self.dis)

        # state of the object
        self.latent_size = latent_size
        self.depth = depth
        self.use_ema = use_ema
        self.ema_decay = ema_decay
        self.n_critic = n_critic
        self.use_eql = use_eql
        self.device = device
        self.drift = drift

        # define the optimizers for the discriminator and generator
        self.gen_optim = Adam(self.gen.parameters(),
                              lr=learning_rate,
                              betas=(beta_1, beta_2),
                              eps=eps)

        self.dis_optim = Adam(self.dis.parameters(),
                              lr=learning_rate,
                              betas=(beta_1, beta_2),
                              eps=eps)

        # define the loss function used for training the GAN
        self.loss = self.__setup_loss(loss)

        # setup the ema for the generator
        if self.use_ema:
            from pro_gan_pytorch.CustomLayers import EMA
            self.ema = EMA(self.ema_decay)
            self.__register_generator_to_ema()

    def __register_generator_to_ema(self):
        for name, param in self.gen.named_parameters():
            if param.requires_grad:
                self.ema.register(name, param.data)

    def __apply_ema_on_generator(self):
        for name, param in self.gen.named_parameters():
            if param.requires_grad:
                param.data = self.ema(name, param.data)

    def __setup_loss(self, loss):
        import pro_gan_pytorch.Losses as losses

        if isinstance(loss, str):
            loss = loss.lower()  # lowercase the string
            if loss == "wgan":
                loss = losses.WGAN_GP(self.device,
                                      self.dis,
                                      self.drift,
                                      use_gp=False)
                # note: if you use just wgan, you will have to use weight clipping
                # in order to prevent the gradients from exploding

            elif loss == "wgan-gp":
                loss = losses.WGAN_GP(self.device,
                                      self.dis,
                                      self.drift,
                                      use_gp=True)

            elif loss == "lsgan":
                loss = losses.LSGAN(self.device, self.dis)

            elif loss == "lsgan-with-sigmoid":
                loss = losses.LSGAN_SIGMOID(self.device, self.dis)

            else:
                raise ValueError("Unknown loss function requested")

        elif not isinstance(loss, losses.GANLoss):
            raise ValueError(
                "loss is neither an instance of GANLoss nor a string")

        return loss

    def optimize_discriminator(self, noise, real_batch, depth, alpha):
        """
        performs one step of weight update on discriminator using the batch of data
        :param noise: input noise of sample generation
        :param real_batch: real samples batch
        :param depth: current depth of optimization
        :param alpha: current alpha for fade-in
        :return: current loss (Wasserstein loss)
        """
        from torch.nn import AvgPool2d
        from torch.nn.functional import upsample

        # downsample the real_batch for the given depth
        down_sample_factor = int(np.power(2, self.depth - depth - 1))
        prior_downsample_factor = max(int(np.power(2, self.depth - depth)), 0)

        ds_real_samples = AvgPool2d(down_sample_factor)(real_batch)

        if depth > 0:
            prior_ds_real_samples = upsample(
                AvgPool2d(prior_downsample_factor)(real_batch), scale_factor=2)
        else:
            prior_ds_real_samples = ds_real_samples

        # real samples are a combination of ds_real_samples and prior_ds_real_samples
        real_samples = (alpha * ds_real_samples) + (
            (1 - alpha) * prior_ds_real_samples)

        loss_val = 0
        for _ in range(self.n_critic):
            # generate a batch of samples
            fake_samples = self.gen(noise, depth, alpha).detach()

            loss = self.loss.dis_loss(real_samples, fake_samples, depth, alpha)

            # optimize discriminator
            self.dis_optim.zero_grad()
            loss.backward()
            self.dis_optim.step()

            loss_val += loss.item()

        return loss_val / self.n_critic

    def optimize_generator(self, noise, depth, alpha):
        """
        performs one step of weight update on generator for the given batch_size
        :param noise: input random noise required for generating samples
        :param depth: depth of the network at which optimization is done
        :param alpha: value of alpha for fade-in effect
        :return: current loss (Wasserstein estimate)
        """

        # generate fake samples:
        fake_samples = self.gen(noise, depth, alpha)

        # TODO: Change this implementation for making it compatible for relativisticGAN
        loss = self.loss.gen_loss(None, fake_samples, depth, alpha)

        # optimize the generator
        self.gen_optim.zero_grad()
        loss.backward()
        self.gen_optim.step()

        # if use_ema is true, apply ema to the generator parameters
        if self.use_ema:
            self.__apply_ema_on_generator()

        # return the loss value
        return loss.item()
Example #6
def main(args):
    set_seeds(args.rand_seed)
    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('one of train/val/test mode must be chosen.')

    if args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be chosen.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained reasonModel?')

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    ########+++++++++++++++++++++++++++++
    abs_path = os.path.abspath(args.data_path)
    args.data_path = abs_path
    ########+++++++++++++++++++++++++++++
    # Write logs to checkpoint and console
    set_logger(args)
    if args.cuda:
        if args.do_debug:
            if args.gpu_num > 1:
                device_ids, used_memory = gpu_setting(args.gpu_num)
            else:
                device_ids, used_memory = gpu_setting()
            if used_memory > 100:
                logging.info('Using memory = {}'.format(used_memory))
            if device_ids is not None:
                if len(device_ids) > args.gpu_num:
                    device_ids = device_ids[:args.gpu_num]
                device = torch.device('cuda:{}'.format(device_ids[0]))
            else:
                device = torch.device('cuda:0')
            logging.info('Set the cuda with idxes = {}'.format(device_ids))
            logging.info('cuda setting {}'.format(device))
        else:
            if args.gpu_num > 1:
                logging.info("Using GPU!")
                available_device_count = torch.cuda.device_count()
                logging.info('GPU number is {}'.format(available_device_count))
                if args.gpu_num > available_device_count:
                    args.gpu_num = available_device_count
                # ++++++++++++++++++++++++++++++++++
                device_ids, used_memory = gpu_setting(args.gpu_num)
                # ++++++++++++++++++++++++++++++++++
                device = torch.device("cuda:{}".format(device_ids[0]))
                # ++++++++++++++++++++++++++++++++++
            else:
                device = torch.device("cuda:0")
                device_ids = None
                logging.info('Single GPU setting')
    else:
        device = torch.device('cpu')
        device_ids = None
        logging.info('CPU setting')

    logging.info('Device = {}, Device ids = {}'.format(device, device_ids))

    logging.info('Loading training data...')
    train_data_loader, train_data_size = get_train_data_loader(args=args)
    estimated_max_steps = args.epoch * (
        (train_data_size // args.batch_size) + 1)
    if estimated_max_steps > args.max_steps:
        args.max_steps = args.epoch * (
            (train_data_size // args.batch_size) + 1)
    logging.info('Loading development data...')
    dev_data_loader, _ = get_dev_data_loader(args=args)
    logging.info('Loading data completed')
    logging.info('*' * 75)
    tokenizer = get_hotpotqa_longformer_tokenizer()
    logging.info('*' * 75)
    # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    if args.do_train:
        # Set training configuration
        start_time = time()
        logging.info('Loading Model...')
        model = get_model(args=args).to(device)
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.learning_rate,
                                     weight_decay=args.weight_decay)
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        if device_ids is not None:
            if len(device_ids) > 1:
                model = DataParallel(model,
                                     device_ids=device_ids,
                                     output_device=device)
                logging.info('Data Parallel model setting')
        logging.info('Model Parameter Configuration:')
        for name, param in model.named_parameters():
            logging.info('Parameter {}: {}, require_grad = {}'.format(
                name, str(param.size()), str(param.requires_grad)))
        logging.info('*' * 75)
        logging.info("Model hype-parameter information...")
        for key, value in vars(args).items():
            logging.info('Hype-parameter\t{} = {}'.format(key, value))
        logging.info('*' * 75)
        logging.info('batch_size = {}'.format(args.batch_size))
        logging.info('projection_dim = {}'.format(args.project_dim))
        logging.info('learning_rate = {}'.format(args.learning_rate))
        logging.info('Start training...')
        train_all_steps(model=model,
                        optimizer=optimizer,
                        dev_dataloader=dev_data_loader,
                        device=device,
                        train_dataloader=train_data_loader,
                        tokenizer=tokenizer,
                        args=args)
        logging.info('Completed training in {:.4f} seconds'.format(time() -
                                                                   start_time))
        logging.info('Evaluating on Valid Dataset...')
        metric_dict = test_all_steps(model=model,
                                     tokenizer=tokenizer,
                                     test_data_loader=dev_data_loader,
                                     args=args)
        answer_type_acc = metric_dict['answer_type_acc']
        logging.info('*' * 75)
        logging.info(
            'Answer type prediction accuracy: {}'.format(answer_type_acc))
        logging.info('*' * 75)
        for key, value in metric_dict.items():
            if key.endswith('metrics'):
                logging.info('{} prediction'.format(key))
                log_metrics('Valid', 'final', value)
        logging.info('*' * 75)
        ##++++++++++++++++++++++++++++++++++++++++++++++++++++
        ##++++++++++++++++++++++++++++++++++++++++++++++++++++
        model_save_path = save_check_point(model=model,
                                           optimizer=optimizer,
                                           step='all_step',
                                           loss=None,
                                           eval_metric=None,
                                           args=args)
        logging.info('Saving the model in {}'.format(model_save_path))
Example #7
class base(object):
    def __init__(self, args):
        # initialize hyper-parameters
        self.data = args.data
        self.gan_type = args.gan_type
        self.d_depth = args.d_depth
        self.dowmsampling = args.dowmsampling
        self.gpu_counts = args.gpu_counts
        self.power = args.power
        self.batch_size = args.batch_size
        self.use_gpu = torch.cuda.is_available()
        self.u_depth = args.u_depth
        self.is_pretrained_unet = args.is_pretrained_unet
        self.pretrain_unet_path = args.pretrain_unet_path

        self.lr = args.lr
        self.debug = args.debug
        self.prefix = args.prefix
        self.interval = args.interval
        self.n_update_gan = args.n_update_gan
        self.epochs = args.epochs
        self.gamma = args.gamma
        self.beta1 = args.beta1

        self.training_strategies = args.training_strategies
        self.epoch_interval = 1 if self.debug else 50

        self.logger = Logger(add_prefix(self.prefix, 'tensorboard'))
        # normalize the images between [-1 and 1]
        self.mean = [0.5, 0.5, 0.5]
        self.std = [0.5, 0.5, 0.5]
        self.dataloader = self.get_dataloader()
        self.d = get_discriminator(self.gan_type, self.d_depth,
                                   self.dowmsampling)
        self.unet = self.get_unet()

        self.log_lst = []

        if self.use_gpu:
            self.unet = DataParallel(self.unet).cuda()
            self.d = DataParallel(self.d).cuda()
        else:
            raise RuntimeWarning('there is no gpu available.')
        self.save_init_paras()
        self.get_optimizer()
        self.save_hyperparameters(args)

    def save_hyperparameters(self, args):
        write(vars(args), add_prefix(self.prefix, 'para.txt'))
        print('save hyperparameters successfully.')

    def get_lr(self):
        lr = []
        for param_group in self.d_optimizer.param_groups:
            lr += [param_group['lr']]
        return lr[0]

    def restore(self, x):
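        # Undo the [-1, 1] normalization and convert the CHW tensor back to an HWC uint8 image.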
        x = torch.squeeze(x)
        x = x.data.cpu()
        for t, m, s in zip(x, self.mean, self.std):
            t.mul_(s).add_(m)
        x = x.numpy()
        x = np.transpose(x, (1, 2, 0))
        x = np.clip(x * 255, 0, 255).astype(np.uint8)
        return x

    def get_dataloader(self):
        if self.data == './data/gan':
            print('load DR with size 128 successfully!!')
        else:
            raise ValueError("the parameter data must be in ['./data/gan']")
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(self.mean, self.std)])
        dataset = ConcatDataset(data_dir=self.data,
                                transform=transform,
                                alpha=self.power)
        data_loader = DataLoader(dataset,
                                 batch_size=self.batch_size,
                                 shuffle=True,
                                 num_workers=2,
                                 drop_last=True,
                                 pin_memory=True if self.use_gpu else False)
        return data_loader

    def get_unet(self):
        unet = UNet(3, depth=self.u_depth, in_channels=3)
        print(unet)
        print(
            'load unet with depth %d; downsampling will be performed %d times!!'
            % (self.u_depth, self.u_depth - 1))
        if self.is_pretrained_unet:
            unet.load_state_dict(weight_to_cpu(self.pretrain_unet_path))
            print('load pretrained unet')
        return unet

    def main(self):
        assert hasattr(self, 'u_lr_scheduler') and hasattr(
            self, 'd_lr_scheduler')
        print('training start!')
        start_time = time.time()
        print(
            'd will be updated %d times for each single update of g.' %
            self.n_update_gan)
        if self.interval % self.n_update_gan != 0:
            warnings.warn(
                "Hyperparameter interval should be divisible by n_update_gan."
            )
        for epoch in range(1, self.epochs + 1):
            self.u_lr_scheduler.step()
            self.d_lr_scheduler.step()

            self.train(epoch)
            if epoch % self.epoch_interval == 0:
                with torch.no_grad():
                    self.validate(epoch)
        with torch.no_grad():
            self.validate(self.epochs)

        total_ptime = time.time() - start_time
        if not self.debug:
            # note:relative path is based on the script u_d.py
            print('Training complete in {:.0f}m {:.0f}s'.format(
                total_ptime // 60, total_ptime % 60))

    def validate(self, epoch):
        """
        eval mode
        """
        real_data_score = []
        fake_data_score = []
        for i, (lesion_data, _, lesion_names, _, real_data, _, normal_names,
                _) in enumerate(self.dataloader):
            if i > 2:
                break
            if self.use_gpu:
                lesion_data, real_data = lesion_data.cuda(), real_data.cuda()
            phase = 'lesion_data'
            prefix_path = '%s/epoch_%d/%s' % (self.prefix, epoch, phase)
            lesion_output = self.d(self.unet(lesion_data))
            fake_data_score += list(
                lesion_output.squeeze().cpu().data.numpy().flatten())

            for idx in range(self.batch_size):
                single_image = lesion_data[idx:(idx + 1), :, :, :]
                single_name = lesion_names[idx]
                self.save_image(prefix_path, single_name, single_image)
                if self.debug:
                    break

            phase = 'normal_data'
            prefix_path = '%s/epoch_%d/%s' % (self.prefix, epoch, phase)
            normal_output = self.d(real_data)
            real_data_score += list(
                normal_output.squeeze().cpu().data.numpy().flatten())

            for idx in range(self.batch_size):
                single_image = real_data[idx:(idx + 1), :, :, :]
                single_name = normal_names[idx]
                self.save_image(prefix_path, single_name, single_image)
                if self.debug:
                    break

        prefix_path = '%s/epoch_%d' % (self.prefix, epoch)

        self.plot_hist('%s/score_distribution.png' % prefix_path,
                       real_data_score, fake_data_score)
        torch.save(self.unet.state_dict(), add_prefix(prefix_path, 'g.pkl'))
        torch.save(self.d.state_dict(), add_prefix(prefix_path, 'd.pkl'))
        print('save model parameters successfully when epoch=%d' % epoch)

    def save_image(self, saved_path, name, inputs):
        """
        save unet output as a form of image
        """
        if not os.path.exists(saved_path):
            os.makedirs(saved_path)
        output = self.unet(inputs)

        left = self.restore(inputs)
        right = self.restore(output)
        # Note: subtracting the uint8 arrays directly would wrap around whenever left < right
        # (e.g. left=217, right=220 gives 253), so compute the absolute difference branch-wise with np.where below.
        diff = np.where(left > right, left - right,
                        right - left).clip(0, 255).astype(np.uint8)
        plt.figure(num='unet result', figsize=(8, 8))
        plt.subplot(2, 2, 1)
        plt.title('source image')
        plt.imshow(left)
        plt.axis('off')
        plt.subplot(2, 2, 2)
        plt.title('unet output')
        plt.imshow(right)
        plt.axis('off')
        plt.subplot(2, 2, 3)
        plt.imshow(rgb2gray(diff), cmap='jet')
        plt.colorbar(orientation='horizontal')
        plt.title('difference in heatmap')
        plt.axis('off')
        plt.subplot(2, 2, 4)
        plt.imshow(rgb2gray(diff.clip(0, 32)), cmap='jet')
        plt.colorbar(orientation='horizontal')
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(add_prefix(saved_path, name))
        plt.close()

    def save_gradient(self, epoch, idx):
        """
        check bottom and top layer's gradient
        """
        if epoch % self.epoch_interval == 0 or epoch == 1:
            saved_path = '%s/gradient_epoch_%d' % (self.prefix, epoch)
            weights_top, weights_bottom = self.get_top_bottom_layer()
            weights_top, weights_bottom = list(
                weights_top.cpu().data.numpy().flatten()), list(
                    weights_bottom.cpu().data.numpy().flatten())
            self.plot_gradient(saved_path, 'weights_top_%d.png' % idx,
                               weights_top)
            self.plot_gradient(saved_path, 'weights_bottom_%d.png' % idx,
                               weights_bottom)

    def plot_gradient(self, saved_path, phase, weights):
        """
        display gradient distribution in histogram
        """
        bins = np.linspace(min(weights), max(weights), 60)
        plt.hist(weights,
                 bins=bins,
                 alpha=0.3,
                 label='gradient',
                 edgecolor='k')
        plt.legend(loc='upper right')
        if not os.path.exists(saved_path):
            os.makedirs(saved_path)
        plt.savefig('%s/%s' % (saved_path, phase))
        plt.close()

    def plot_hist(self, path, real_data, fake_data):
        bins = np.linspace(min(min(real_data), min(fake_data)),
                           max(max(real_data), max(fake_data)), 60)
        plt.hist(real_data,
                 bins=bins,
                 alpha=0.3,
                 label='real_score',
                 edgecolor='k')
        plt.hist(fake_data,
                 bins=bins,
                 alpha=0.3,
                 label='fake_score',
                 edgecolor='k')
        plt.legend(loc='upper right')
        plt.savefig(path)
        plt.close()

    def get_top_bottom_layer(self):
        """
        save gradient of top and bottom layers to double-check and analyse
        """
        layer_names = list(dict(self.d.named_parameters()).keys())
        for name in layer_names:
            if 'conv' in name and 'weight' in name and 'bn' not in name and 'fc' not in name \
                    and not name.endswith('_u') and not name.endswith('_v'):
                bottom = name
                break
        for name in layer_names[::-1]:
            if 'conv' in name and 'weight' in name and 'bn' not in name and 'fc' not in name \
                    and not name.endswith('_u') and not name.endswith('_v'):
                top = name
                break
        return dict(self.d.named_parameters())[top].grad, dict(
            self.d.named_parameters())[bottom].grad

    def train(self, epoch):
        pass

    def get_optimizer(self):
        pass

    def save_running_script(self, script_path):
        """
        save the main running script to get differences between scripts
        """
        copy(script_path, add_prefix(self.prefix, script_path.split('/')[-1]))

    def save_log(self):
        write_list(self.log_lst, add_prefix(self.prefix, 'log.txt'))
        print('save running log successfully')

    def save_init_paras(self):
        if not os.path.exists(self.prefix):
            os.makedirs(self.prefix)

        torch.save(self.unet.state_dict(),
                   add_prefix(self.prefix, 'init_g_para.pkl'))
        torch.save(self.d.state_dict(),
                   add_prefix(self.prefix, 'init_d_para.pkl'))
        print('save initial model parameters successfully')

    def load_pretrained_model(self):
        pass

    def load_config(self):
        pass
Example #8
class Solver():
    def __init__(self, config, channel_list):
        # Config - Model
        self.z_dim = config.z_dim
        self.channel_list = channel_list

        # Config - Training
        self.batch_size = config.batch_size
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.decay_ratio = config.decay_ratio
        self.decay_iter = config.decay_iter
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.n_critic = config.n_critic
        self.lambda_gp = config.lambda_gp
        self.max_iter = config.max_iter

        # Config - Test
        self.fixed_z = torch.rand(128, config.z_dim, 1, 1).to(dev)

        # Config - Path
        self.data_root = config.data_root
        self.log_root = config.log_root
        self.model_root = config.model_root
        self.sample_root = config.sample_root
        self.result_root = config.result_root

        # Config - Miscellaneous
        self.print_loss_iter = config.print_loss_iter
        self.save_image_iter = config.save_image_iter
        self.save_parameter_iter = config.save_parameter_iter
        self.save_log_iter = config.save_log_iter

        self.writer = SummaryWriter(self.log_root)

    def build_model(self):
        self.G = Generator(channel_list=self.channel_list)
        self.G_ema = Generator(channel_list=self.channel_list)
        self.D = Discriminator(channel_list=self.channel_list)

        self.G = DataParallel(self.G).to(dev)
        self.G_ema = DataParallel(self.G_ema).to(dev)
        self.D = DataParallel(self.D).to(dev)

        self.g_optimizer = torch.optim.Adam(params=self.G.parameters(),
                                            lr=self.g_lr,
                                            betas=[self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(params=self.D.parameters(),
                                            lr=self.d_lr,
                                            betas=[self.beta1, self.beta2])

        self.g_scheduler = lr_scheduler.StepLR(self.g_optimizer,
                                               step_size=self.decay_iter,
                                               gamma=self.decay_ratio)
        self.d_scheduler = lr_scheduler.StepLR(self.d_optimizer,
                                               step_size=self.decay_iter,
                                               gamma=self.decay_ratio)

        print("Print model G, D")
        print(self.G)
        print(self.D)

    def load_model(self, pkl_path, channel_list):
        ckpt = torch.load(pkl_path)

        self.G = Generator(channel_list=channel_list)
        self.G_ema = Generator(channel_list=channel_list)
        self.D = Discriminator(channel_list=channel_list)

        self.G = DataParallel(self.G).to(dev)
        self.G_ema = DataParallel(self.G_ema).to(dev)
        self.D = DataParallel(self.D).to(dev)

        self.G.load_state_dict(ckpt["G"])
        self.G_ema.load_state_dict(ckpt["G_ema"])
        self.D.load_state_dict(ckpt["D"])

    def save_model(self, iters, step):
        file_name = 'ckpt_%d_%d.pkl' % (2 * (2**(step + 1)), iters)
        ckpt_path = os.path.join(self.model_root, file_name)
        ckpt = {
            'G': self.G.state_dict(),
            'G_ema': self.G_ema.state_dict(),
            'D': self.D.state_dict()
        }
        torch.save(ckpt, ckpt_path)

    def save_img(self, iters, fixed_z, step):
        img_path = os.path.join(self.sample_root,
                                "%d_%d.png" % (2 * (2**(step + 1)), iters))
        with torch.no_grad():
            generated_imgs = self.G_ema(fixed_z[:self.batch_size].to(dev),
                                        step, 1)
            save_image(
                make_grid(generated_imgs.cpu() / 2 + 1 / 2, nrow=4, padding=2),
                img_path)

    def reset_grad(self):
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def lr_update(self):
        self.g_scheduler.step()
        self.d_scheduler.step()

    def set_phase(self, mode="train"):
        if mode == "train":
            self.G.train()
            self.G_ema.train()
            self.D.train()

        elif mode == "test":
            self.G.eval()
            self.G_ema.eval()
            self.D.eval()

    def exponential_moving_average(self, beta=0.999):
        with torch.no_grad():
            G_param_dict = dict(self.G.named_parameters())
            for name, g_ema_param in self.G_ema.named_parameters():
                g_param = G_param_dict[name]
                g_ema_param.copy_(beta * g_ema_param + (1. - beta) * g_param)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(dev)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm - 1)**2)

    def train(self):
        # build model
        self.build_model()

        for step in range(len(self.channel_list)):
            if step > 4:
                self.batch_size = self.batch_size // 2
            loader = data_loader(self.data_root,
                                 self.batch_size,
                                 img_size=2 * (2**(step + 1)))
            loader = iter(cycle(loader))

            if step == 0 or step == 1 or step == 2:
                self.max_iter = 20000
            elif step == 3 or step == 4 or step == 5:
                self.max_iter = 50000
            else:
                self.max_iter = 100000

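            # alpha controls the fade-in of the newly added resolution block; it is increased by
            # 1 / (max_iter // 2) each iteration (see the update after the optimizer steps).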
            alpha = 0.0

            for iters in range(self.max_iter + 1):
                real_img = next(loader)
                real_img = real_img.to(dev)

                # ===============================================================#
                #                    1. Train the discriminator                  #
                # ===============================================================#
                self.set_phase(mode="train")
                self.reset_grad()

                # Compute loss with real images.
                d_real_out = self.D(real_img, step, alpha)
                d_loss_real = -d_real_out.mean()

                # Compute loss with fake images.
                z = torch.rand(self.batch_size, self.z_dim, 1, 1).to(dev)
                fake_img = self.G(z, step, alpha)
                d_fake_out = self.D(fake_img.detach(), step, alpha)
                d_loss_fake = d_fake_out.mean()

                # Compute loss for gradient penalty.
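                # x_hat is a random per-sample interpolation between real and fake images (WGAN-GP).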
                beta = torch.rand(self.batch_size, 1, 1, 1).to(dev)
                x_hat = (beta * real_img.data +
                         (1 - beta) * fake_img.data).requires_grad_(True)
                d_x_hat_out = self.D(x_hat, step, alpha)
                d_loss_gp = self.gradient_penalty(d_x_hat_out, x_hat)

                # Backward and optimize.
                d_loss = d_loss_real + d_loss_fake + self.lambda_gp * d_loss_gp
                d_loss.backward()
                self.d_optimizer.step()

                # ===============================================================#
                #                      2. Train the Generator                    #
                # ===============================================================#

                if (iters + 1) % self.n_critic == 0:
                    self.reset_grad()

                    # Compute loss with fake images.
                    fake_img = self.G(z, step, alpha)
                    d_fake_out = self.D(fake_img, step, alpha)
                    g_loss = -d_fake_out.mean()

                    # Backward and optimize.
                    g_loss.backward()
                    self.g_optimizer.step()

                # ===============================================================#
                #                   3. Save parameters and images                #
                # ===============================================================#
                # self.lr_update()
                torch.cuda.synchronize()
                alpha += 1 / (self.max_iter // 2)
                self.set_phase(mode="test")
                self.exponential_moving_average()

                # Print total loss
                if iters % self.print_loss_iter == 0:
                    print(
                        "Step : [%d/%d], Iter : [%d/%d], D_loss : [%.3f, %.3f, %.3f., %.3f], G_loss : %.3f"
                        %
                        (step, len(self.channel_list) - 1, iters,
                         self.max_iter, d_loss.item(), d_loss_real.item(),
                         d_loss_fake.item(), d_loss_gp.item(), g_loss.item()))

                # Save generated images.
                if iters % self.save_image_iter == 0:
                    self.save_img(iters, self.fixed_z, step)

                # Save the G and D parameters.
                if iters % self.save_parameter_iter == 0:
                    self.save_model(iters, step)

                # Save the logs on the tensorboard.
                if iters % self.save_log_iter == 0:
                    self.writer.add_scalar('g_loss/g_loss', g_loss.item(),
                                           iters)
                    self.writer.add_scalar('d_loss/d_loss_total',
                                           d_loss.item(), iters)
                    self.writer.add_scalar('d_loss/d_loss_real',
                                           d_loss_real.item(), iters)
                    self.writer.add_scalar('d_loss/d_loss_fake',
                                           d_loss_fake.item(), iters)
                    self.writer.add_scalar('d_loss/d_loss_gp',
                                           d_loss_gp.item(), iters)
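
The loop above calls two trainer methods that are not shown in this snippet. gradient_penalty(d_x_hat_out, x_hat) is presumably the standard WGAN-GP penalty (the norm of the critic's gradient at the interpolated samples pushed toward 1); below is a minimal sketch with the signature inferred from the call site, not the repository's actual code:

import torch

def gradient_penalty(self, d_out, x_hat):
    # dD(x_hat)/dx_hat, keeping the graph so the penalty itself can be backpropagated.
    grad = torch.autograd.grad(outputs=d_out.sum(), inputs=x_hat,
                               create_graph=True, retain_graph=True)[0]
    grad = grad.view(grad.size(0), -1)
    # Penalize the deviation of each sample's gradient norm from 1.
    return ((grad.norm(2, dim=1) - 1.0) ** 2).mean()

exponential_moving_average(), called right after switching to test mode, presumably maintains a shadow copy of the generator weights used for image saving and evaluation.
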
Example #9
loader_params = {'batch_size': 96 // NUM_GPUS, 'num_gpus':NUM_GPUS, 'num_workers':num_workers}
train_loader = VCRLoader.from_dataset(train, **loader_params)
val_loader = VCRLoader.from_dataset(val, **loader_params)
test_loader = VCRLoader.from_dataset(test, **loader_params)

ARGS_RESET_EVERY = 100
print("Loading {} for {}".format(params['model'].get('type', 'WTF?'), 'rationales' if args.rationale else 'answer'), flush=True)
model = Model.from_params(vocab=train.vocab, params=params['model'])
for submodule in model.detector.backbone.modules():
    if isinstance(submodule, BatchNorm2d):
        submodule.track_running_stats = False
    for p in submodule.parameters():
        p.requires_grad = False

model = DataParallel(model).cuda() if NUM_GPUS > 1 else model.cuda()
optimizer = Optimizer.from_params([x for x in model.named_parameters() if x[1].requires_grad],
                                  params['trainer']['optimizer'])

lr_scheduler_params = params['trainer'].pop("learning_rate_scheduler", None)
scheduler = LearningRateScheduler.from_params(optimizer, lr_scheduler_params) if lr_scheduler_params else None

if os.path.exists(args.folder):
    print("Found folder! restoring", flush=True)
    start_epoch, val_metric_per_epoch = restore_checkpoint(model, optimizer, serialization_dir=args.folder,
                                                           learning_rate_scheduler=scheduler)
else:
    print("Making directories")
    os.makedirs(args.folder, exist_ok=True)
    start_epoch, val_metric_per_epoch = 0, []
    shutil.copy2(args.params, args.folder)
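
restore_checkpoint is used above but not defined in this snippet. Below is a hypothetical sketch of what such a helper typically does: load the newest model and optimizer state from the serialization directory and return the epoch to resume from together with the validation-metric history. The file-naming scheme and state keys are assumptions, not the project's actual format.

import os
import re
import torch

def restore_checkpoint(model, optimizer, serialization_dir, learning_rate_scheduler=None):
    # Assumed naming: 'model_state_epoch_N.th' / 'training_state_epoch_N.th'.
    ckpts = [f for f in os.listdir(serialization_dir) if f.startswith('model_state_epoch_')]
    if not ckpts:
        return 0, []
    epoch = max(int(re.findall(r'\d+', f)[0]) for f in ckpts)
    model.load_state_dict(torch.load(
        os.path.join(serialization_dir, 'model_state_epoch_{}.th'.format(epoch)),
        map_location='cpu'))
    training_state = torch.load(
        os.path.join(serialization_dir, 'training_state_epoch_{}.th'.format(epoch)),
        map_location='cpu')
    optimizer.load_state_dict(training_state['optimizer'])
    if (learning_rate_scheduler is not None
            and hasattr(learning_rate_scheduler, 'load_state_dict')
            and 'learning_rate_scheduler' in training_state):
        learning_rate_scheduler.load_state_dict(training_state['learning_rate_scheduler'])
    return epoch + 1, training_state.get('val_metric_per_epoch', [])
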
Example #10
    def train(self):
        torch.multiprocessing.set_sharing_strategy('file_system')

        path = self.args.data_path
        label_file = self.args.label_path
        self.logger.info('original train process')
        time_stamp_launch = time.strftime('%Y%m%d') + '-' + time.strftime(
            '%H%M')
        self.logger.info(path.split('/')[-2] + time_stamp_launch)
        best_acc = 0
        model_root = './model_' + path.split('/')[-2]
        if not os.path.exists(model_root):
            os.mkdir(model_root)
        cuda = True
        cudnn.benchmark = True
        batch_size = self.args.batchsize
        batch_size_g = batch_size * 2
        image_size = (224, 224)
        num_cls = self.args.num_class

        self.generator_epoch = self.args.generator_epoch
        self.warm_epoch = 10
        n_epoch = self.args.max_epoch
        weight_decay = 1e-6
        momentum = 0.9

        manual_seed = random.randint(1, 10000)
        random.seed(manual_seed)
        torch.manual_seed(manual_seed)

        #######################
        # load data           #
        #######################
        target_train = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomCrop((224, 224)),
            transforms.RandomHorizontalFlip(),
            AutoAugment(),
            transforms.ToTensor(),
            transforms.Normalize((0.435, 0.418, 0.396),
                                 (0.284, 0.308, 0.335)),  # dataset-specific RGB mean/std
        ])

        dataset_train = visDataset_target(path,
                                          label_file,
                                          train=True,
                                          transform=target_train)

        dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train,
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       num_workers=3)
        transform_test = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.435, 0.418, 0.396),
                                 (0.284, 0.308, 0.335)),  # dataset-specific RGB mean/std
        ])

        test_dataset = visDataset_target(path,
                                         label_file,
                                         train=True,
                                         transform=transform_test)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=3)

        #####################
        #  load model       #
        #####################
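        # LinearAverage is presumably a non-parametric memory bank (one 2048-d feature
        # slot per target sample, temperature 0.05) as used in instance-discrimination
        # methods; elr_loss is the early-learning regularization loss, which helps
        # resist memorization of noisy pseudo labels.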
        self.lemniscate = LinearAverage(2048, len(test_dataset), 0.05,
                                        0.00).cuda()
        self.elr_loss = elr_loss(num_examp=len(test_dataset),
                                 num_classes=12).cuda()

        generator = generator_fea_deconv(class_num=num_cls)

        discriminator = Discriminator_fea()
        source_net = torch.load(self.args.source_model_path)
        source_classifier = Classifier(num_classes=num_cls)
        fea_contrastor = contrastor()

        # load pre-trained source classifier
        fc_dict = source_classifier.state_dict()
        pre_dict = source_net.state_dict()
        pre_dict = {k: v for k, v in pre_dict.items() if k in fc_dict}
        fc_dict.update(pre_dict)
        source_classifier.load_state_dict(fc_dict)

        generator = DataParallel(generator, device_ids=[0, 1])
        discriminator = DataParallel(discriminator, device_ids=[0, 1])
        fea_contrastor = DataParallel(fea_contrastor, device_ids=[0, 1])
        source_net = DataParallel(source_net, device_ids=[0, 1])
        source_classifier = DataParallel(source_classifier, device_ids=[0, 1])
        source_classifier.eval()

        for p in generator.parameters():
            p.requires_grad = True
        for p in source_net.parameters():
            p.requires_grad = True

        # freezing the source classifier
        for name, value in source_net.named_parameters():
            if name[:9] == 'module.fc':
                value.requires_grad = False

        # setup optimizer
        params = filter(lambda p: p.requires_grad, source_net.parameters())
        discriminator_group = []
        for k, v in discriminator.named_parameters():
            discriminator_group += [{'params': v, 'lr': self.lr * 3}]

        model_params = []
        for v in params:
            model_params += [{'params': v, 'lr': self.lr}]

        contrastor_para = []
        for k, v in fea_contrastor.named_parameters():
            contrastor_para += [{'params': v, 'lr': self.lr * 5}]

        #####################
        # setup optimizer   #
        #####################

        # only train the extractor
        optimizer = optim.SGD(model_params + discriminator_group +
                              contrastor_para,
                              momentum=momentum,
                              weight_decay=weight_decay)
        optimizer_g = optim.SGD(generator.parameters(),
                                lr=self.lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

        loss_gen_ce = torch.nn.CrossEntropyLoss()

        if cuda:
            source_net = source_net.cuda()
            generator = generator.cuda()
            discriminator = discriminator.cuda()
            fea_contrastor = fea_contrastor.cuda()
            loss_gen_ce = loss_gen_ce.cuda()
            source_classifier = source_classifier.cuda()

        #############################
        # training network          #
        #############################

        len_dataloader = len(dataloader_train)
        self.logger.info('the step of one epoch: ' + str(len_dataloader))

        current_step = 0
        for epoch in range(n_epoch):
            source_net.train()
            discriminator.train()
            fea_contrastor.train()

            data_train_iter = iter(dataloader_train)

            if epoch < self.generator_epoch:
                generator.train()
                self.train_prototype_generator(epoch, batch_size_g, num_cls,
                                               optimizer_g, generator,
                                               source_classifier, loss_gen_ce)

            if epoch >= self.generator_epoch:
                if epoch == self.generator_epoch:
                    torch.save(
                        generator, model_root + '/generator_' +
                        path.split('/')[-2] + '.pkl')

                # prototype generation
                generator.eval()
                z = Variable(torch.rand(self.args.num_class * 2, 100)).cuda()

                # Get labels ranging from 0 to n_classes for n rows
                label_t = torch.linspace(0, num_cls - 1, steps=num_cls).long()
                for ti in range(self.args.num_class * 2 // num_cls - 1):
                    label_t = torch.cat([
                        label_t,
                        torch.linspace(0, num_cls - 1, steps=num_cls).long()
                    ])
                labels = Variable(label_t).cuda()
                z = z.contiguous()
                labels = labels.contiguous()
                images = generator(z, labels)

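                # Confidence/mixing coefficient: decays linearly from 0.9 to 0.7 as
                # epoch runs from generator_epoch to n_epoch (its exact use is inside
                # adaptation_step, which is not shown in this snippet).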
                self.alpha = 0.9 - (epoch - self.generator_epoch) / (
                    n_epoch - self.generator_epoch) * 0.2

                # obtain the target pseudo label and confidence weight
                pseudo_label, pseudo_label_acc, all_indx, confidence_weight = self.obtain_pseudo_label_and_confidence_weight(
                    test_loader, source_net)

                i = 0
                while i < len_dataloader:
                    ###################################
                    #      prototype adaptation      #
                    ###################################
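                    # GRL-style ramp as in DANN: p grows from 0 to 1 over the adaptation
                    # epochs, and self.p = 2 / (1 + exp(-10 p)) - 1 rises smoothly from
                    # 0 toward 1.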
                    p = float(i +
                              (epoch - self.generator_epoch) * len_dataloader
                              ) / (n_epoch -
                                   self.generator_epoch) / len_dataloader
                    self.p = 2. / (1. + np.exp(-10 * p)) - 1
                    data_target_train = next(data_train_iter)
                    s_img, s_label, s_indx = data_target_train

                    batch_size_s = len(s_label)

                    input_img_s = torch.FloatTensor(batch_size_s, 3,
                                                    image_size[0],
                                                    image_size[1])
                    class_label_s = torch.LongTensor(batch_size_s)

                    if cuda:
                        s_img = s_img.cuda()
                        s_label = s_label.cuda()
                        input_img_s = input_img_s.cuda()
                        class_label_s = class_label_s.cuda()

                    input_img_s.resize_as_(s_img).copy_(s_img)
                    class_label_s.resize_as_(s_label).copy_(s_label)
                    target_inputv_img = Variable(input_img_s)
                    target_classv_label = Variable(class_label_s)

                    # learning rate decay
                    optimizer = self.exp_lr_scheduler(optimizer=optimizer,
                                                      step=current_step)

                    loss, contrastive_loss = self.adaptation_step(
                        target_inputv_img, pseudo_label, images.detach(),
                        labels, s_indx.numpy(), source_net, discriminator,
                        fea_contrastor, optimizer, epoch,
                        confidence_weight.float())

                    # visualization on tensorboard
                    self.writer.add_scalar('contrastive_loss',
                                           contrastive_loss,
                                           global_step=current_step)
                    self.writer.add_scalar('overall_loss',
                                           loss,
                                           global_step=current_step)
                    self.writer.add_scalar('pseudo_label_acc',
                                           pseudo_label_acc,
                                           global_step=current_step)

                    i += 1
                    current_step += 1

                self.logger.info('epoch: %d' % epoch)
                self.logger.info('contrastive_loss: %f' % (contrastive_loss))
                self.logger.info('loss: %f' % loss)
                accu, ac_list = val_pclass(source_net, test_loader)
                self.writer.add_scalar('test_acc',
                                       accu,
                                       global_step=current_step)
                self.logger.info(ac_list)
                if accu >= best_acc:
                    self.logger.info('saving the best model!')
                    torch.save(
                        source_net, model_root + '/' + time_stamp_launch +
                        '_best_model_' + path.split('/')[-2] + '.pkl')
                    best_acc = accu

                self.logger.info('acc is : %.04f, best acc is : %.04f' %
                                 (accu, best_acc))
                self.logger.info(
                    '================================================')

        self.logger.info('training done! ! !')
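
The adaptation loop relies on several trainer methods that are not shown here (adaptation_step, exp_lr_scheduler, obtain_pseudo_label_and_confidence_weight). Below is a minimal sketch of the last one, assuming the common recipe of taking the argmax of the source model's softmax output as the pseudo label and its maximum probability as the per-sample confidence weight; the names and return order are inferred from the call site, and the project's real implementation may differ.

import torch
import torch.nn.functional as F

def obtain_pseudo_label_and_confidence_weight(self, loader, net):
    # One pass of the (frozen-classifier) source network over the whole target set.
    net.eval()
    all_pred, all_conf, all_idx, correct, total = [], [], [], 0, 0
    with torch.no_grad():
        for img, label, idx in loader:
            prob = F.softmax(net(img.cuda()), dim=1)
            conf, pred = prob.max(dim=1)
            all_pred.append(pred.cpu())
            all_conf.append(conf.cpu())
            all_idx.append(idx)
            correct += (pred.cpu() == label).sum().item()
            total += label.size(0)
    net.train()
    pseudo_label = torch.cat(all_pred)
    confidence_weight = torch.cat(all_conf)
    all_indx = torch.cat(all_idx)
    pseudo_label_acc = correct / total
    return pseudo_label, pseudo_label_acc, all_indx, confidence_weight

Under this recipe, pseudo_label_acc is simply the agreement between the pseudo labels and the ground-truth labels carried by visDataset_target, which is why it can be logged to TensorBoard during adaptation.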