Code example #1
    def validate(self):
        self.model.eval()
        with torch.no_grad():
            for i, (image, _) in enumerate(self.target_loader):
                image = image.cuda(non_blocking=True).float()

                # forward
                output_mask = self.model(image)

                if i % self.args.vis_freq_inval == 0:
                    image = image[:self.args.vis_batch]
                    if self.args.data_modality == 'oct':
                        # OCT: {0, 1, ..., 11}
                        # gt: BWH
                        # model output: BCWH (C=12)

                        # BCWH -> BWH -> B1WH
                        output_mask = F.log_softmax(output_mask, dim=1)
                        _, output_mask = torch.max(output_mask, dim=1)
                        output_mask = output_mask.float().unsqueeze(dim=1)

                        # {0, 1, ..., 11} -> (0, 1)
                        output_mask = torch.clamp(
                            output_mask[:self.args.vis_batch] / 11, 0, 1)
                    else:
                        # fundus: {0, 1}, B1WH
                        output_mask = output_mask[:self.args.vis_batch]

                    save_images = torch.cat([image, output_mask], dim=0)
                    output_save = os.path.join(self.args.output_root,
                                               self.args.project, 'output',
                                               self.args.version, 'val')
                    if not os.path.exists(output_save):
                        os.makedirs(output_save)
                    tv.utils.save_image(save_images,
                                        os.path.join(output_save,
                                                     '{}.png'.format(i)),
                                        nrow=self.args.vis_batch)

                    # print('val: [Batch {}/{}]'.format(i, self.target_loader.__len__()))

        save_ckpt(version=self.args.version,
                  state={
                      'epoch': self.epoch,
                      'state_dict_G': self.model.model_G.state_dict(),
                      'state_dict_D': self.model.model_D.state_dict(),
                  },
                  epoch=self.epoch,
                  args=self.args)
        print('Save ckpt successfully!')
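
The save_ckpt helper called above is project code that is not shown on this page. Below is a minimal sketch consistent with the keyword arguments at the call site (version, state, epoch, args); the directory layout, file names, and the optional is_best flag are assumptions, not the project's actual implementation.

import os
import torch

def save_ckpt(version, state, epoch, args, is_best=False):
    # Hypothetical helper matching the call sites in this example:
    # `state` already holds the epoch and the G/D state dicts.
    ckpt_dir = os.path.join(args.output_root, args.project, 'checkpoint', version)
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(state, os.path.join(ckpt_dir, 'ckpt_epoch_{}.pth.tar'.format(epoch)))
    if is_best:
        torch.save(state, os.path.join(ckpt_dir, 'ckpt_best.pth.tar'))
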
Code example #2
def main(opt):
    start_epoch = 0
    err_best = 10000
    lr_now = opt.lr
    is_cuda = torch.cuda.is_available()

    # save option in log
    script_name = os.path.basename(__file__).split('.')[0]
    script_name = script_name + '_3D_in{:d}_out{:d}_dct_n_{:d}'.format(
        opt.input_n, opt.output_n, opt.dct_n)

    # create model
    print(">>> creating model")
    input_n = opt.input_n
    output_n = opt.output_n
    dct_n = opt.dct_n
    sample_rate = opt.sample_rate

    model = nnmodel.GCN(input_feature=dct_n,
                        hidden_feature=opt.linear_size,
                        p_dropout=opt.dropout,
                        num_stage=opt.num_stage,
                        node_n=66)

    if is_cuda:
        model.cuda()

    print(">>> total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    if opt.is_load:
        model_path_len = 'checkpoint/test/' + 'ckpt_' + script_name + '_last.pth.tar'
        print(">>> loading ckpt len from '{}'".format(model_path_len))
        if is_cuda:
            ckpt = torch.load(model_path_len)
        else:
            ckpt = torch.load(model_path_len, map_location='cpu')
        start_epoch = ckpt['epoch']
        err_best = ckpt['err']
        lr_now = ckpt['lr']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
        print(">>> ckpt len loaded (epoch: {} | err: {})".format(
            start_epoch, err_best))

    # data loading
    print(">>> loading data")
    train_dataset = H36motion3D(path_to_data=opt.data_dir,
                                actions='all',
                                input_n=input_n,
                                output_n=output_n,
                                split=0,
                                dct_used=dct_n,
                                sample_rate=sample_rate)

    acts = data_utils.define_actions('all')
    test_data = dict()
    for act in acts:
        test_dataset = H36motion3D(path_to_data=opt.data_dir,
                                   actions=act,
                                   input_n=input_n,
                                   output_n=output_n,
                                   split=1,
                                   sample_rate=sample_rate,
                                   dct_used=dct_n)
        test_data[act] = DataLoader(dataset=test_dataset,
                                    batch_size=opt.test_batch,
                                    shuffle=False,
                                    num_workers=opt.job,
                                    pin_memory=True)
    val_dataset = H36motion3D(path_to_data=opt.data_dir,
                              actions='all',
                              input_n=input_n,
                              output_n=output_n,
                              split=2,
                              dct_used=dct_n,
                              sample_rate=sample_rate)

    # load datasets for training
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=opt.train_batch,
                              shuffle=True,
                              num_workers=opt.job,
                              pin_memory=True)
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=opt.test_batch,
                            shuffle=False,
                            num_workers=opt.job,
                            pin_memory=True)
    print(">>> data loaded !")
    print(">>> train data {}".format(train_dataset.__len__()))
    print(">>> test data {}".format(test_dataset.__len__()))
    print(">>> validation data {}".format(val_dataset.__len__()))

    for epoch in range(start_epoch, opt.epochs):

        if (epoch + 1) % opt.lr_decay == 0:
            lr_now = utils.lr_decay(optimizer, lr_now, opt.lr_gamma)

        print('==========================')
        print('>>> epoch: {} | lr: {:.5f}'.format(epoch + 1, lr_now))
        ret_log = np.array([epoch + 1])
        head = np.array(['epoch'])
        # per epoch
        lr_now, t_l = train(train_loader,
                            model,
                            optimizer,
                            lr_now=lr_now,
                            max_norm=opt.max_norm,
                            is_cuda=is_cuda,
                            dim_used=train_dataset.dim_used,
                            dct_n=dct_n)
        ret_log = np.append(ret_log, [lr_now, t_l])
        head = np.append(head, ['lr', 't_l'])

        v_3d = val(val_loader,
                   model,
                   is_cuda=is_cuda,
                   dim_used=train_dataset.dim_used,
                   dct_n=dct_n)

        ret_log = np.append(ret_log, [v_3d])
        head = np.append(head, ['v_3d'])

        test_3d_temp = np.array([])
        test_3d_head = np.array([])
        for act in acts:
            test_l, test_3d = test(test_data[act],
                                   model,
                                   input_n=input_n,
                                   output_n=output_n,
                                   is_cuda=is_cuda,
                                   dim_used=train_dataset.dim_used,
                                   dct_n=dct_n)
            # ret_log = np.append(ret_log, test_l)
            ret_log = np.append(ret_log, test_3d)
            head = np.append(
                head,
                [act + '3d80', act + '3d160', act + '3d320', act + '3d400'])
            if output_n > 10:
                head = np.append(head, [act + '3d560', act + '3d1000'])
        ret_log = np.append(ret_log, test_3d_temp)
        head = np.append(head, test_3d_head)

        # update log file and save checkpoint
        df = pd.DataFrame(np.expand_dims(ret_log, axis=0))
        if epoch == start_epoch:
            df.to_csv(opt.ckpt + '/' + script_name + '.csv',
                      header=head,
                      index=False)
        else:
            with open(opt.ckpt + '/' + script_name + '.csv', 'a') as f:
                df.to_csv(f, header=False, index=False)
        if not np.isnan(v_3d):
            is_best = v_3d < err_best
            err_best = min(v_3d, err_best)
        else:
            is_best = False
        file_name = [
            'ckpt_' + script_name + '_best.pth.tar',
            'ckpt_' + script_name + '_last.pth.tar'
        ]
        utils.save_ckpt(
            {
                'epoch': epoch + 1,
                'lr': lr_now,
                'err': test_3d[0],
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            },
            ckpt_path=opt.ckpt,
            is_best=is_best,
            file_name=file_name)
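
Here utils.save_ckpt receives an is_best flag and a [best, last] file-name pair. A plausible sketch of such a helper, with the signature and behaviour inferred from the call site (always write the "last" checkpoint, promote it to "best" when validation improved), might be:

import os
import shutil
import torch

def save_ckpt(state, ckpt_path='checkpoint/test', is_best=True, file_name=None):
    # Hypothetical helper inferred from the call above: file_name is
    # [best_name, last_name]; the "last" checkpoint is always written and
    # copied to "best" when the validation error improved.
    if file_name is None:
        file_name = ['ckpt_best.pth.tar', 'ckpt_last.pth.tar']
    last_path = os.path.join(ckpt_path, file_name[1])
    torch.save(state, last_path)
    if is_best:
        shutil.copyfile(last_path, os.path.join(ckpt_path, file_name[0]))
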
Code example #3
def train():
    train_dirs = [train_dir for train_dir in [args.train_dir, args.train_dir2] if train_dir is not None]
    train_data = ACNet_data.FreiburgForest(
        transform=transforms.Compose([
            ACNet_data.ScaleNorm(),
            # ACNet_data.RandomRotate((-13, 13)),
            # ACNet_data.RandomSkew((-0.05, 0.10)),
            ACNet_data.RandomScale((1.0, 1.4)),
            ACNet_data.RandomHSV((0.9, 1.1),
                                 (0.9, 1.1),
                                 (25, 25)),
            ACNet_data.RandomCrop(image_h, image_w),
            ACNet_data.RandomFlip(),
            ACNet_data.ToTensor(),
            ACNet_data.Normalize()
        ]),
        data_dirs=train_dirs,
        modal1_name=args.modal1,
        modal2_name=args.modal2,
    )

    valid_dirs = [valid_dir for valid_dir in [args.valid_dir, args.valid_dir2] if valid_dir is not None]
    valid_data = ACNet_data.FreiburgForest(
        transform=transforms.Compose([
            ACNet_data.ScaleNorm(),
            ACNet_data.ToTensor(),
            ACNet_data.Normalize()
        ]),
        data_dirs=valid_dirs,
        modal1_name=args.modal1,
        modal2_name=args.modal2,
    )

    '''
    # Split dataset into training and validation
    dataset_length = len(data)
    valid_split = 0.05  # tiny split due to the small size of the dataset
    valid_length = int(valid_split * dataset_length)
    train_length = dataset_length - valid_length
    train_data, valid_data = torch.utils.data.random_split(data, [train_length, valid_length])
    '''

    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, pin_memory=False)
    valid_loader = DataLoader(valid_data, batch_size=args.batch_size * 3, shuffle=False,
                              num_workers=1, pin_memory=False)

    # Initialize model
    if args.last_ckpt:
        model = ACNet_models_V1.ACNet(num_class=5, pretrained=False)
    else:
        model = ACNet_models_V1.ACNet(num_class=5, pretrained=True)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model.train()
    model.to(device)

    # Initialize criterion, optimizer and scheduler
    criterion = utils.CrossEntropyLoss2d(weight=freiburgforest_frq)
    criterion.to(device)

    # TODO: try with different optimizers and schedulers (CyclicLR exp_range for example)
    # TODO: try with a smaller LR (currently loss decay is too steep and then doesn't change)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    # lr_decay_lambda = lambda epoch: args.lr_decay_rate ** (epoch // args.lr_epoch_per_decay)
    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_decay_lambda)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=int(np.ceil(args.epochs / 7)), T_mult=2, eta_min=8e-4)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=5e-4)
    global_step = 0

    # TODO: add early stop to avoid overfitting

    # Continue training from previous checkpoint
    if args.last_ckpt:
        global_step, args.start_epoch = utils.load_ckpt(model, optimizer, scheduler, args.last_ckpt, device)

    writer = SummaryWriter(args.summary_dir)
    losses = []
    for epoch in tqdm(range(int(args.start_epoch), args.epochs)):

        if epoch % args.save_epoch_freq == 0 and epoch != args.start_epoch:
            utils.save_ckpt(args.ckpt_dir, model, optimizer, scheduler, global_step, epoch)

        for batch_idx, sample in enumerate(train_loader):
            modal1, modal2 = sample['modal1'].to(device), sample['modal2'].to(device)
            target_scales = [sample[s].to(device) for s in ['label', 'label2', 'label3', 'label4', 'label5']]

            optimizer.zero_grad()
            pred_scales = model(modal1, modal2, args.checkpoint)
            loss = criterion(pred_scales, target_scales)
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
            global_step += 1
            if global_step % args.print_freq == 0 or global_step == 1:

                for name, param in model.named_parameters():
                    writer.add_histogram(name, param.detach().cpu().numpy(), global_step, bins='doane')

                grid_image = make_grid(modal1[:3].detach().cpu(), 3, normalize=False)
                writer.add_image('Modal1', grid_image, global_step)
                grid_image = make_grid(modal2[:3].detach().cpu(), 3, normalize=False)
                writer.add_image('Modal2', grid_image, global_step)
                grid_image = make_grid(utils.color_label(torch.argmax(pred_scales[0][:3], 1) + 1), 3, normalize=True,
                                       range=(0, 255))
                writer.add_image('Prediction', grid_image, global_step)
                grid_image = make_grid(utils.color_label(target_scales[0][:3]), 3, normalize=True, range=(0, 255))
                writer.add_image('GroundTruth', grid_image, global_step)
                writer.add_scalar('Loss', loss.item(), global_step=global_step)
                writer.add_scalar('Loss average', sum(losses) / len(losses), global_step=global_step)
                writer.add_scalar('Learning rate', scheduler.get_last_lr()[0], global_step=global_step)

                # Compute validation metrics
                with torch.no_grad():
                    model.eval()

                    losses_val = []
                    acc_list = []
                    iou_list = []
                    for sample_val in valid_loader:
                        modal1_val, modal2_val = sample_val['modal1'].to(device), sample_val['modal2'].to(device)
                        target_val = sample_val['label'].to(device)
                        pred_val = model(modal1_val, modal2_val)

                        losses_val.append(criterion([pred_val], [target_val]).item())
                        acc_list.append(utils.accuracy(
                            (torch.argmax(pred_val, 1) + 1).detach().cpu().numpy().astype(int),
                            target_val.detach().cpu().numpy().astype(int))[0])
                        iou_list.append(utils.compute_IoU(
                            y_pred=(torch.argmax(pred_val, 1) + 1).detach().cpu().numpy().astype(int),
                            y_true=target_val.detach().cpu().numpy().astype(int),
                            num_classes=5
                        ))

                    writer.add_scalar('Loss validation', sum(losses_val) / len(losses_val), global_step=global_step)
                    writer.add_scalar('Accuracy', sum(acc_list) / len(acc_list), global_step=global_step)
                    iou = np.mean(np.stack(iou_list, axis=0), axis=0)
                    writer.add_scalar('IoU_Road', iou[0], global_step=global_step)
                    writer.add_scalar('IoU_Grass', iou[1], global_step=global_step)
                    writer.add_scalar('IoU_Vegetation', iou[2], global_step=global_step)
                    writer.add_scalar('IoU_Sky', iou[3], global_step=global_step)
                    writer.add_scalar('IoU_Obstacle', iou[4], global_step=global_step)
                    writer.add_scalar('mIoU', np.mean(iou), global_step=global_step)

                    model.train()

                losses = []

        scheduler.step()

    utils.save_ckpt(args.ckpt_dir, model, optimizer, scheduler, global_step, args.epochs)
    print("Training completed ")
Code example #4
def train_seq2seq(model, dataloaders):
    train_dataloader, dev_dataloader, test_dataloader = dataloaders
    criterion = nn.CrossEntropyLoss(ignore_index=constant.pad_idx)
    if constant.optim == 'Adam':
        opt = torch.optim.Adam(model.parameters(), lr=constant.lr)
    elif constant.optim == 'SGD':
        opt = torch.optim.SGD(model.parameters(), lr=constant.lr)
    else:
        print("Optim is not defined")
        exit(1)

    start_epoch = 1
    if constant.restore:
        model, opt, start_epoch = load_ckpt(model, opt, constant.restore_path)

    if constant.USE_CUDA:
        model.cuda()
        if constant.embeddings_cpu:
            model.encoder.embedding.cpu()

    best_dev = 10000
    best_path = ''
    patience = 3
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,
                                                           'min',
                                                           factor=0.5,
                                                           patience=0,
                                                           min_lr=1e-6)

    try:
        for e in range(start_epoch, constant.epochs):
            model.train()
            loss_log = []
            ppl_log = []

            if constant.grid_search:
                pbar = enumerate(train_dataloader)
            else:
                pbar = tqdm(enumerate(train_dataloader),
                            total=len(train_dataloader))

            for b, (dialogs, lens, targets, _, _, _, _, _, _) in pbar:
                # periodically release cached GPU memory (every 10 batches)
                if (b + 1) % 10 == 0:
                    torch.cuda.empty_cache()
                opt.zero_grad()
                try:
                    batch_size, max_target_len = targets.shape
                    probs = model(dialogs, lens, targets)

                    # Masked CEL trick: Reshape probs to (B*L, V) and targets to (B*L,) and ignore pad idx
                    probs = probs.transpose(0, 1).contiguous().view(
                        batch_size * max_target_len, -1)
                    targets = targets.contiguous().view(batch_size *
                                                        max_target_len)
                    loss = criterion(probs, targets)
                    # if constant.embeddings_cpu and constant.USE_CUDA:
                    #     targets = targets.cuda()
                    #     target_lens = target_lens.cuda()
                    # loss = masked_cross_entropy(probs.transpose(0, 1).contiguous(), targets.contiguous(), target_lens)
                    loss.backward()
                    opt.step()

                    ## logging
                    loss_log.append(loss.item())
                    ppl_log.append(math.exp(loss_log[-1]))
                    if not constant.grid_search:
                        pbar.set_description(
                            "(Epoch {}) TRAIN LOSS:{:.4f} TRAIN PPL:{:.1f}".
                            format(e, np.mean(loss_log), np.mean(ppl_log)))
                except RuntimeError as err:
                    if 'out of memory' in str(err):
                        print('| WARNING: ran out of memory, skipping batch')
                        torch.cuda.empty_cache()
                    else:
                        raise err

            ## LOG
            dev_loss, dev_ppl = eval_seq2seq(model, dev_dataloader, bleu=False)

            print("(Epoch {}) DEV LOSS: {:.4f}, DEV PPL: {:.1f}".format(
                e, dev_loss, dev_ppl))

            scheduler.step(dev_loss)
            if (dev_loss < best_dev):
                best_dev = dev_loss
                # save best model
                path = 'trained/data-{}.task-seq2seq.lr-{}.emb-{}.D-{}.H-{}.attn-{}.bi-{}.parse-{}.loss-{}'  # lr.embedding.D.H.attn.bi.parse.metric
                path = path.format(constant.data, constant.lr,
                                   constant.embedding, constant.D, constant.H,
                                   constant.attn, constant.bi, constant.parse,
                                   best_dev)
                if constant.topk:
                    path += '.topk-{}.tau-{}'.format(constant.topk_size,
                                                     constant.tau)
                if constant.grid_search:
                    path += '.grid'
                best_path = save_model(model, 'loss', best_dev, path)
                patience = 3
            else:
                patience -= 1
            if patience == 0: break
            if best_dev == 0.0: break

    except KeyboardInterrupt:
        if not constant.grid_search:
            print("KEYBOARD INTERRUPT: Save CKPT and Eval")
            save = input('Save ckpt? (y/n)\t') in ['y', 'Y', 'yes', 'Yes']
            if save:
                save_path = save_ckpt(model, opt, e)
                print("Saved CKPT path: ", save_path)
            # ask if eval
            do_eval = input('Proceed with eval? (y/n)\t') in ['y', 'Y', 'yes', 'Yes']
            if do_eval:
                dev_loss, dev_ppl, dev_bleu, dev_bleus = eval_seq2seq(
                    model, dev_dataloader, bleu=True, beam=constant.beam)
                print("DEV LOSS: {:.4f}, DEV PPL: {:.1f}, DEV BLEU: {:.4f}".
                      format(dev_loss, dev_ppl, dev_bleu))
                print(
                    "BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}"
                    .format(dev_bleus[0], dev_bleus[1], dev_bleus[2],
                            dev_bleus[3]))
        exit(1)

    # load and report best model on test
    torch.cuda.empty_cache()
    model = load_model(model, best_path)
    if constant.USE_CUDA:
        model.cuda()

    # train_loss, train_ppl, train_bleu, train_bleus = eval_seq2seq(model, train_dataloader, bleu=True, beam=constant.beam)
    dev_loss, dev_ppl, dev_bleu, dev_bleus = eval_seq2seq(model,
                                                          dev_dataloader,
                                                          bleu=True,
                                                          beam=constant.beam)
    test_loss, test_ppl, test_bleu, test_bleus = eval_seq2seq(
        model, test_dataloader, bleu=True, beam=constant.beam)

    # print("BEST TRAIN LOSS: {:.4f}, TRAIN PPL: {:.1f}, TRAIN BLEU: {:.4f}".format(train_loss, train_ppl, train_bleu))
    # print("BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}".format(train_bleus[0], train_bleus[1], train_bleus[2], train_bleus[3]))

    print("BEST DEV LOSS: {:.4f}, DEV PPL: {:.1f}, DEV BLEU: {:.4f}".format(
        dev_loss, dev_ppl, dev_bleu))
    print("BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}".
          format(dev_bleus[0], dev_bleus[1], dev_bleus[2], dev_bleus[3]))

    print("BEST TEST LOSS: {:.4f}, TEST PPL: {:.1f}, TEST BLEU: {:.4f}".format(
        test_loss, test_ppl, test_bleu))
    print("BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}".
          format(test_bleus[0], test_bleus[1], test_bleus[2], test_bleus[3]))
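
The "masked CEL trick" in the training loop flattens the decoder scores from (max_target_len, batch, vocab) to (batch*len, vocab) and the targets to (batch*len,), relying on ignore_index to drop padded positions from the loss. A self-contained illustration with made-up shapes and a hypothetical pad index:

import torch
import torch.nn as nn

pad_idx, vocab = 1, 100                                    # hypothetical values for the demo
batch_size, max_target_len = 4, 7
logits = torch.randn(max_target_len, batch_size, vocab)    # decoder output: (L, B, V)
targets = torch.randint(2, vocab, (batch_size, max_target_len))
targets[:, 5:] = pad_idx                                   # pretend the last steps are padding

criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
loss = criterion(
    logits.transpose(0, 1).contiguous().view(batch_size * max_target_len, -1),
    targets.contiguous().view(batch_size * max_target_len),
)  # padded positions contribute nothing to the loss
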
Code example #5
File: train.py Project: lilujunai/person-reid
def main():
  cfg = Config()

  # Redirect logs to both console and file.
  if cfg.log_to_file:
    ReDirectSTD(cfg.stdout_file, 'stdout', False)
    ReDirectSTD(cfg.stderr_file, 'stderr', False)

  # Lazily create SummaryWriter
  writer = None

  TVT, TMO = set_devices(cfg.sys_device_ids)

  if cfg.seed is not None:
    set_seed(cfg.seed)

  # Dump the configurations to log.
  import pprint
  print('-' * 60)
  print('cfg.__dict__')
  pprint.pprint(cfg.__dict__)
  print('-' * 60)

  ###########
  # Dataset #
  ###########

  if not cfg.only_test:
    train_set = create_dataset(**cfg.train_set_kwargs)
    # The combined dataset does not provide val set currently.
    val_set = None if cfg.dataset == 'combined' else create_dataset(**cfg.val_set_kwargs)

  test_sets = []
  test_set_names = []
  if cfg.dataset == 'combined':
    for name in ['market1501', 'cuhk03', 'duke']:
      cfg.test_set_kwargs['name'] = name
      test_sets.append(create_dataset(**cfg.test_set_kwargs))
      test_set_names.append(name)
  else:
    test_sets.append(create_dataset(**cfg.test_set_kwargs))
    test_set_names.append(cfg.dataset)

  ###########
  # Models  #
  ###########

  if cfg.only_test:
    model = Model(cfg.net, pretrained=False, last_conv_stride=cfg.last_conv_stride)
  else:
    model = Model(cfg.net, path_to_predefined=cfg.net_pretrained_path, last_conv_stride=cfg.last_conv_stride) # This is a ShuffleNet Network. Model(last_conv_stride=cfg.last_conv_stride)
  
  #############################
  # Criteria and Optimizers   #
  #############################

  tri_loss = TripletLoss(margin=cfg.margin)

  optimizer = optim.Adam(model.parameters(),
                         lr=cfg.base_lr,
                         weight_decay=cfg.weight_decay)
  #optimizer = optimizers.FusedAdam(model.parameters(),
  #                      lr=cfg.base_lr,
  #                      weight_decay=cfg.weight_decay)

  #optimizer = torch.optim.SGD(model.parameters(), cfg.base_lr,
  #                            nesterov=True,
  #                            momentum=cfg.momentum,
  #                            weight_decay=cfg.weight_decay)

  model.cuda()
  model, optimizer = amp.initialize(model, optimizer,
                                    opt_level=cfg.opt_level,
                                    keep_batchnorm_fp32=cfg.keep_batchnorm_fp32,
                                    #loss_scale=cfg.loss_scale
                                    )


  amp.init() # Register function

  # Bind them together just to save some codes in the following usage.
  modules_optims = [model, optimizer]

# Model wrapper
  model_w = DataParallel(model)


  ################################
  # May Resume Models and Optims #
  ################################

  if cfg.resume:
    resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

  # May Transfer Models and Optims to Specified Device. Transferring optimizer
  # is to cope with the case when you load the checkpoint to a new device.
  TMO(modules_optims)

  ########
  # Test #
  ########

  def test(load_model_weight=False):
    if load_model_weight:
      if cfg.model_weight_file != '':
        map_location = (lambda storage, loc: storage)
        sd = torch.load(cfg.model_weight_file, map_location=map_location)
        load_state_dict(model, sd)
        print('Loaded model weights from {}'.format(cfg.model_weight_file))
      else:
        load_ckpt(modules_optims, cfg.ckpt_file)

    for test_set, name in zip(test_sets, test_set_names):
      feature_map = ExtractFeature(model_w, TVT)
      test_set.set_feat_func(feature_map)
      print('\n=========> Test on dataset: {} <=========\n'.format(name))
      test_set.eval(
        normalize_feat=cfg.normalize_feature,
        verbose=True)

  def validate():
    if val_set.extract_feat_func is None:
      feature_map = ExtractFeature(model_w, TVT)
      val_set.set_feat_func(feature_map)
    print('\n=========> Test on validation set <=========\n')
    mAP, cmc_scores, _, _ = val_set.eval(
      normalize_feat=cfg.normalize_feature,
      to_re_rank=False,
      verbose=False)
    print()
    return mAP, cmc_scores[0]

  if cfg.only_test:
    test(load_model_weight=True)
    return

  ############
  # Training #
  ############

  start_ep = resume_ep if cfg.resume else 0

  for ep in range(start_ep, cfg.total_epochs):

    # Adjust Learning Rate
    if cfg.lr_decay_type == 'exp':
      adjust_lr_exp(
        optimizer,
        cfg.base_lr,
        ep + 1,
        cfg.total_epochs,
        cfg.exp_decay_at_epoch)
    else:
      adjust_lr_staircase(
        optimizer,
        cfg.base_lr,
        ep + 1,
        cfg.staircase_decay_at_epochs,
        cfg.staircase_decay_multiply_factor)

    may_set_mode(modules_optims, 'train')

    # For recording precision, satisfying margin, etc
    prec_meter = AverageMeter()
    sm_meter = AverageMeter()
    dist_ap_meter = AverageMeter()
    dist_an_meter = AverageMeter()
    loss_meter = AverageMeter()

    ep_st = time.time()
    step = 0
    epoch_done = False
    while not epoch_done:

      step += 1
      step_st = time.time()

      ims, im_names, labels, mirrored, epoch_done = train_set.next_batch()

      ims_var = Variable(TVT(torch.from_numpy(ims).float()))

      labels_t = TVT(torch.from_numpy(labels).long())

      feat = model_w(ims_var)

      loss, p_inds, n_inds, dist_ap, dist_an, dist_mat = global_loss(
        tri_loss, feat, labels_t,
        normalize_feature=cfg.normalize_feature)

      optimizer.zero_grad()
      
      with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
      
      optimizer.step()

      ############
      # Step Log #
      ############

      # precision
      prec = (dist_an > dist_ap).data.float().mean()
      # the proportion of triplets that satisfy margin
      sm = (dist_an > dist_ap + cfg.margin).data.float().mean()
      # average (anchor, positive) distance
      d_ap = dist_ap.data.mean()
      # average (anchor, negative) distance
      d_an = dist_an.data.mean()

      prec_meter.update(prec)
      sm_meter.update(sm)
      dist_ap_meter.update(d_ap)
      dist_an_meter.update(d_an)
      loss_meter.update(to_scalar(loss))

      if step % cfg.steps_per_log == 0:
        time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
          step, ep + 1, time.time() - step_st, )

        tri_log = (', prec {:.2%}, sm {:.2%}, '
                   'd_ap {:.4f}, d_an {:.4f}, '
                   'loss {:.4f}'.format(
          prec_meter.val, sm_meter.val,
          dist_ap_meter.val, dist_an_meter.val,
          loss_meter.val, ))

        log = time_log + tri_log
        print(log)

    #############
    # Epoch Log #
    #############

    time_log = 'Ep {}, {:.2f}s'.format(ep + 1, time.time() - ep_st)

    tri_log = (', prec {:.2%}, sm {:.2%}, '
               'd_ap {:.4f}, d_an {:.4f}, '
               'loss {:.4f}'.format(
      prec_meter.avg, sm_meter.avg,
      dist_ap_meter.avg, dist_an_meter.avg,
      loss_meter.avg, ))

    log = time_log + tri_log
    print(log)

    ##########################
    # Test on Validation Set #
    ##########################

    mAP, Rank1 = 0, 0
    if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
      mAP, Rank1 = validate()

    # Log to TensorBoard

    if cfg.log_to_file:
      if writer is None:
        writer = SummaryWriter(log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
      writer.add_scalars(
        'val scores',
        dict(mAP=mAP,
             Rank1=Rank1),
        ep)
      writer.add_scalars(
        'loss',
        dict(loss=loss_meter.avg, ),
        ep)
      writer.add_scalars(
        'precision',
        dict(precision=prec_meter.avg, ),
        ep)
      writer.add_scalars(
        'satisfy_margin',
        dict(satisfy_margin=sm_meter.avg, ),
        ep)
      writer.add_scalars(
        'average_distance',
        dict(dist_ap=dist_ap_meter.avg,
             dist_an=dist_an_meter.avg, ),
        ep)

    # save ckpt
    if cfg.log_to_file:
      save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

  ########
  # Test #
  ########

  test(load_model_weight=False)
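
The learning-rate helpers adjust_lr_exp and adjust_lr_staircase are project utilities not shown on this page. A sketch of the exponential variant consistent with the arguments used above (the 0.001 final decay factor and the exact formula are assumptions):

def adjust_lr_exp(optimizer, base_lr, ep, total_ep, start_decay_at_ep):
    # Hypothetical sketch: keep base_lr until start_decay_at_ep, then decay
    # exponentially so the LR approaches base_lr * 0.001 by the last epoch.
    if ep < start_decay_at_ep:
        lr = base_lr
    else:
        lr = base_lr * (0.001 ** (float(ep + 1 - start_decay_at_ep)
                                  / (total_ep + 1 - start_decay_at_ep)))
    for g in optimizer.param_groups:
        g['lr'] = lr
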
Code example #6
File: ACNet_train_V1_nyuv2.py Project: jtpils/ACNet
def train():
    train_data = ACNet_data.SUNRGBD(transform=transforms.Compose([ACNet_data.scaleNorm(),
                                                                   ACNet_data.RandomScale((1.0, 1.4)),
                                                                   ACNet_data.RandomHSV((0.9, 1.1),
                                                                                         (0.9, 1.1),
                                                                                         (25, 25)),
                                                                   ACNet_data.RandomCrop(image_h, image_w),
                                                                   ACNet_data.RandomFlip(),
                                                                   ACNet_data.ToTensor(),
                                                                   ACNet_data.Normalize()]),
                                     phase_train=True,
                                     data_dir=args.data_dir)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, pin_memory=False)

    num_train = len(train_data)

    if args.last_ckpt:
        model = ACNet_models_V1.ACNet(num_class=40, pretrained=False)
    else:
        model = ACNet_models_V1.ACNet(num_class=40, pretrained=True)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    CEL_weighted = utils.CrossEntropyLoss2d(weight=nyuv2_frq)
    model.train()
    model.to(device)
    CEL_weighted.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)

    global_step = 0

    if args.last_ckpt:
        global_step, args.start_epoch = load_ckpt(model, optimizer, args.last_ckpt, device)

    lr_decay_lambda = lambda epoch: args.lr_decay_rate ** (epoch // args.lr_epoch_per_decay)
    scheduler = LambdaLR(optimizer, lr_lambda=lr_decay_lambda)

    writer = SummaryWriter(args.summary_dir)

    for epoch in range(int(args.start_epoch), args.epochs):

        scheduler.step(epoch)
        local_count = 0
        last_count = 0
        end_time = time.time()
        if epoch % args.save_epoch_freq == 0 and epoch != args.start_epoch:
            save_ckpt(args.ckpt_dir, model, optimizer, global_step, epoch,
                      local_count, num_train)

        for batch_idx, sample in enumerate(train_loader):

            image = sample['image'].to(device)
            depth = sample['depth'].to(device)
            target_scales = [sample[s].to(device) for s in ['label', 'label2', 'label3', 'label4', 'label5']]
            optimizer.zero_grad()
            pred_scales = model(image, depth, args.checkpoint)
            loss = CEL_weighted(pred_scales, target_scales)
            loss.backward()
            optimizer.step()
            local_count += image.data.shape[0]
            global_step += 1
            if global_step % args.print_freq == 0 or global_step == 1:

                time_inter = time.time() - end_time
                count_inter = local_count - last_count
                print_log(global_step, epoch, local_count, count_inter,
                          num_train, loss, time_inter)
                end_time = time.time()

                for name, param in model.named_parameters():
                    writer.add_histogram(name, param.clone().cpu().data.numpy(), global_step, bins='doane')
                grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True)
                writer.add_image('image', grid_image, global_step)
                grid_image = make_grid(depth[:3].clone().cpu().data, 3, normalize=True)
                writer.add_image('depth', grid_image, global_step)
                grid_image = make_grid(utils.color_label(torch.max(pred_scales[0][:3], 1)[1] + 1), 3, normalize=False,
                                       range=(0, 255))
                writer.add_image('Predicted label', grid_image, global_step)
                grid_image = make_grid(utils.color_label(target_scales[0][:3]), 3, normalize=False, range=(0, 255))
                writer.add_image('Groundtruth label', grid_image, global_step)
                writer.add_scalar('CrossEntropyLoss', loss.data, global_step=global_step)
                writer.add_scalar('Learning rate', scheduler.get_lr()[0], global_step=global_step)
                last_count = local_count

    save_ckpt(args.ckpt_dir, model, optimizer, global_step, args.epochs,
              0, num_train)

    print("Training completed ")
Code example #7
    def validate_cls(self):
        # self.model.eval()
        self.model.train()

        with torch.no_grad():
            """
            Difference: abnormal dataloader and abnormal_list
            """
            if self.args.data_modality == 'fundus':
                myopia_gt_list, myopia_pred_list = self.forward_cls_dataloader(
                    loader=self.myopia_fundus_loader, is_disease=True)

                amd_gt_list, amd_pred_list = self.forward_cls_dataloader(
                    loader=self.amd_fundus_loader, is_disease=True)
                glaucoma_gt_list, glaucoma_pred_list = self.forward_cls_dataloader(
                    loader=self.glaucoma_fundus_loader, is_disease=True)
                dr_gt_list, dr_pred_list = self.forward_cls_dataloader(
                    loader=self.dr_fundus_loader, is_disease=True)
            else:
                abnormal_gt_list, abnormal_pred_list = self.forward_cls_dataloader(
                    loader=self.oct_abnormal_loader, is_disease=True)

            _, normal_train_pred_list = self.forward_cls_dataloader(
                loader=self.train_loader, is_disease=False)
            normal_gt_list, normal_pred_list = self.forward_cls_dataloader(
                loader=self.normal_test_loader, is_disease=False)
            """
            compute metrics
            """
            # Difference: total_true_list and total_pred_list
            if self.args.data_modality == 'fundus':
                # test metrics for myopia
                m_true_list = myopia_gt_list + normal_gt_list
                m_pred_list = myopia_pred_list + normal_pred_list
                # test metrics for amd
                a_true_list = amd_gt_list + normal_gt_list
                a_pred_list = amd_pred_list + normal_pred_list
                # test metrics for glaucoma
                g_true_list = glaucoma_gt_list + normal_gt_list
                g_pred_list = glaucoma_pred_list + normal_pred_list
                # test metrics for dr
                d_true_list = dr_gt_list + normal_gt_list
                d_pred_list = dr_pred_list + normal_pred_list
                # total
                total_true_list = a_true_list + myopia_gt_list + glaucoma_gt_list + dr_gt_list
                total_pred_list = a_pred_list + myopia_pred_list + glaucoma_pred_list + dr_pred_list

                # fpr, tpr, thresholds = metrics.roc_curve()
                myopia_auc = metrics.roc_auc_score(np.array(m_true_list),
                                                   np.array(m_pred_list))
                amd_auc = metrics.roc_auc_score(np.array(a_true_list),
                                                np.array(a_pred_list))
                glaucoma_auc = metrics.roc_auc_score(np.array(g_true_list),
                                                     np.array(g_pred_list))
                dr_auc = metrics.roc_auc_score(np.array(d_true_list),
                                               np.array(d_pred_list))
            else:
                total_true_list = abnormal_gt_list + normal_gt_list
                total_pred_list = abnormal_pred_list + normal_pred_list

            # get roc curve and compute the auc
            fpr, tpr, thresholds = metrics.roc_curve(np.array(total_true_list),
                                                     np.array(total_pred_list))
            total_auc = metrics.auc(fpr, tpr)
            """
            compute threshold, and then compute the accuracy
            """
            percentage = 0.75
            _threshold_for_acc = sorted(normal_train_pred_list)[int(
                len(normal_train_pred_list) * percentage)]
            normal_cls_pred_list = [(0 if i < _threshold_for_acc else 1)
                                    for i in normal_pred_list]
            amd_cls_pred_list = [(0 if i < _threshold_for_acc else 1)
                                 for i in amd_pred_list]
            myopia_cls_pred_list = [(0 if i < _threshold_for_acc else 1)
                                    for i in myopia_pred_list]
            glaucoma_cls_pred_list = [(0 if i < _threshold_for_acc else 1)
                                      for i in glaucoma_pred_list]
            dr_cls_pred_list = [(0 if i < _threshold_for_acc else 1)
                                for i in dr_pred_list]

            # acc, sensitivity and specificity
            def calcu_cls_acc(pred_list, gt_list):
                cls_pred_list = normal_cls_pred_list + pred_list
                gt_list = normal_gt_list + gt_list
                acc = metrics.accuracy_score(y_true=gt_list,
                                             y_pred=cls_pred_list)
                tn, fp, fn, tp = metrics.confusion_matrix(
                    y_true=gt_list, y_pred=cls_pred_list).ravel()
                sen = tp / (tp + fn + 1e-7)
                spe = tn / (tn + fp + 1e-7)
                return acc, sen, spe

            total_acc, total_sen, total_spe = calcu_cls_acc(
                amd_cls_pred_list + myopia_cls_pred_list,
                amd_gt_list + myopia_gt_list)
            amd_acc, amd_sen, amd_spe = calcu_cls_acc(amd_cls_pred_list,
                                                      amd_gt_list)
            myopia_acc, myopia_sen, myopia_spe = calcu_cls_acc(
                myopia_cls_pred_list, myopia_gt_list)

            # update
            if self.args.data_modality == 'fundus':
                self.myopia_auc_last20.update(myopia_auc)
                self.amd_auc_last20.update(amd_auc)

            self.total_auc_last20.update(total_auc)
            mean, deviation = self.total_auc_top10.top_update_calc(total_auc)

            self.is_best = total_auc > self.best_auc
            self.best_auc = max(total_auc, self.best_auc)
            """
            plot metrics curve
            """
            # ROC curve
            self.vis.draw_roc(fpr, tpr)
            # total auc, primary metrics
            self.vis.plot_single_win(dict(value=total_auc,
                                          best=self.best_auc,
                                          last_avg=self.total_auc_last20.avg,
                                          last_std=self.total_auc_last20.std,
                                          top_avg=mean,
                                          top_dev=deviation),
                                     win='total_auc')

            self.vis.plot_single_win(dict(total_acc=total_acc,
                                          total_sen=total_sen,
                                          total_spe=total_spe,
                                          amd_acc=amd_acc,
                                          amd_sen=amd_sen,
                                          amd_spe=amd_spe,
                                          myopia_acc=myopia_acc,
                                          myopia_sen=myopia_sen,
                                          myopia_spe=myopia_spe),
                                     win='accuracy')

            # Difference
            if self.args.data_modality == 'fundus':
                self.vis.plot_single_win(dict(
                    value=amd_auc,
                    last_avg=self.amd_auc_last20.avg,
                    last_std=self.amd_auc_last20.std),
                                         win='amd_auc')
                self.vis.plot_single_win(dict(
                    value=myopia_auc,
                    last_avg=self.myopia_auc_last20.avg,
                    last_std=self.myopia_auc_last20.std),
                                         win='myopia_auc')

                metrics_str = 'best_auc = {:.4f},' \
                              'total_avg = {:.4f}, total_std = {:.4f}, ' \
                              'total_top_avg = {:.4f}, total_top_dev = {:.4f}, ' \
                              'amd_avg = {:.4f}, amd_std = {:.4f}, ' \
                              'myopia_avg = {:.4f}, myopia_std ={:.4f}'.format(self.best_auc,
                                       self.total_auc_last20.avg, self.total_auc_last20.std,
                                       mean, deviation,
                                       self.amd_auc_last20.avg, self.amd_auc_last20.std,
                                       self.myopia_auc_last20.avg, self.myopia_auc_last20.std)
                metrics_acc_str = '\n total_acc = {:.4f}, total_sen = {:.4f}, total_spe = {:.4f}, ' \
                                  'amd_acc = {:.4f}, amd_sen = {:.4f}, amd_spe = {:.4f}, ' \
                                  'myopia_acc = {:.4f}, myopia_sen = {:.4f}, myopia_spe = {:.4f}'\
                    .format(total_acc, total_sen, total_spe, amd_acc, amd_sen,
                            amd_spe, myopia_acc, myopia_sen, myopia_spe)

            else:
                metrics_str = 'best_auc = {:.4f},' \
                              'total_avg = {:.4f}, total_std = {:.4f}, ' \
                              'total_top_avg = {:.4f}, total_top_dev = {:.4f}'.format(self.best_auc,
                                      self.total_auc_last20.avg,
                                      self.total_auc_last20.std,
                                      mean, deviation)
                metrics_acc_str = '\n None'

            self.vis.text(metrics_str + metrics_acc_str)

        save_ckpt(version=self.args.version,
                  state={
                      'epoch': self.epoch,
                      'state_dict_G': self.model.model_G2.state_dict(),
                      'state_dict_D': self.model.model_D.state_dict(),
                  },
                  epoch=self.epoch,
                  is_best=self.is_best,
                  args=self.args)

        print('\n Save ckpt successfully!')
        print('\n', metrics_str + metrics_acc_str)
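
The thresholding above derives the classification cut-off from the 75th percentile of anomaly scores on normal training images, then reports accuracy, sensitivity and specificity against that threshold. A standalone sketch of the same idea with toy scores (all values below are made up):

import numpy as np
from sklearn import metrics

normal_train_scores = np.random.rand(100)        # stand-in for normal_train_pred_list
test_scores = np.random.rand(40)                 # stand-in for the test-set predictions
test_labels = np.random.randint(0, 2, size=40)   # 0 = normal, 1 = abnormal (toy labels)

# Cut-off: the 75th percentile of scores on normal training images.
threshold = sorted(normal_train_scores)[int(len(normal_train_scores) * 0.75)]
test_cls = [0 if s < threshold else 1 for s in test_scores]

acc = metrics.accuracy_score(y_true=test_labels, y_pred=test_cls)
tn, fp, fn, tp = metrics.confusion_matrix(y_true=test_labels, y_pred=test_cls,
                                          labels=[0, 1]).ravel()
sensitivity = tp / (tp + fn + 1e-7)
specificity = tn / (tn + fp + 1e-7)
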
Code example #8
File: seg_train_val_rdbi.py Project: himmat-l/MAENet
def train():
    # Log data so it can be displayed in TensorBoard
    writer_loss = SummaryWriter(os.path.join(args.summary_dir, 'loss'))
    # writer_loss1 = SummaryWriter(os.path.join(args.summary_dir, 'loss', 'loss1'))
    # writer_loss2 = SummaryWriter(os.path.join(args.summary_dir, 'loss', 'loss2'))
    # writer_loss3 = SummaryWriter(os.path.join(args.summary_dir, 'loss', 'loss3'))
    writer_acc = SummaryWriter(os.path.join(args.summary_dir, 'macc'))

    # Prepare the datasets
    train_data = data_eval.ReadData(transform=transforms.Compose([
        data_eval.scaleNorm(),
        data_eval.RandomScale((1.0, 1.4)),
        data_eval.RandomHSV((0.9, 1.1), (0.9, 1.1), (25, 25)),
        data_eval.RandomCrop(image_h, image_w),
        data_eval.RandomFlip(),
        data_eval.ToTensor(),
        data_eval.Normalize()
    ]),
                                    data_dir=args.train_data_dir)
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=False,
                              drop_last=True)
    val_data = data_eval.ReadData(transform=transforms.Compose([
        data_eval.scaleNorm(),
        data_eval.RandomScale((1.0, 1.4)),
        data_eval.RandomCrop(image_h, image_w),
        data_eval.ToTensor(),
        data_eval.Normalize()
    ]),
                                  data_dir=args.val_data_dir)
    val_loader = DataLoader(val_data,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.workers,
                            pin_memory=False,
                            drop_last=True)
    num_train = len(train_data)
    # num_val = len(val_data)

    # build model
    if args.last_ckpt:
        model = MultiTaskCNN_Atten(38,
                                   depth_channel=1,
                                   pretrained=False,
                                   arch='resnet50',
                                   use_aspp=True)
    else:
        model = MultiTaskCNN_Atten(38,
                                   depth_channel=1,
                                   pretrained=True,
                                   arch='resnet50',
                                   use_aspp=True)

    # build optimizer
    if args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(), args.lr)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.lr)
    else:
        print('Unsupported optimizer\n')
        return None
    global_step = 0
    max_miou_val = 0
    loss_count = 0
    # If a trained checkpoint exists, restore global_step and start_epoch from it
    if args.last_ckpt:
        global_step, args.start_epoch = load_ckpt(model, optimizer,
                                                  args.last_ckpt, device)
    # if torch.cuda.device_count() > 1 and args.cuda and torch.cuda.is_available():
    #     print("Let's use", torch.cuda.device_count(), "GPUs!")
    #     model = torch.nn.DataParallel(model).to(device)
    model = model.to(device)
    model.train()
    # cal_param(model, data)
    loss_func = nn.CrossEntropyLoss()
    for epoch in range(int(args.start_epoch), args.epochs):
        torch.cuda.empty_cache()
        # if epoch <= freeze_epoch:
        #     for layer in [model.conv1, model.maxpool,model.layer1, model.layer2, model.layer3, model.layer4]:
        #         for param in layer.parameters():
        #             param.requires_grad = False
        tq = tqdm(total=len(train_loader) * args.batch_size)
        if loss_count >= 10:
            args.lr = 0.5 * args.lr
            loss_count = 0
        lr = poly_lr_scheduler(optimizer,
                               args.lr,
                               iter=epoch,
                               max_iter=args.epochs)
        optimizer.param_groups[0]['lr'] = lr
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 30, gamma=0.5)
        tq.set_description('epoch %d, lr %f' % (epoch, args.lr))
        loss_record = []
        # loss1_record = []
        # loss2_record = []
        # loss3_record = []
        local_count = 0
        # print('1')
        for batch_idx, data in enumerate(train_loader):
            image = data['image'].to(device)
            depth = data['depth'].to(device)
            label = data['label'].long().to(device)
            # print('label', label.shape)
            output, output_sup1, output_sup2 = model(image, depth)
            loss1 = loss_func(output, label)
            loss2 = loss_func(output_sup1, label)
            loss3 = loss_func(output_sup2, label)
            loss = loss1 + loss2 + loss3
            tq.update(args.batch_size)
            tq.set_postfix(loss='%.6f' % loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1
            local_count += image.data.shape[0]
            # writer_loss.add_scalar('loss_step', loss, global_step)
            # writer_loss1.add_scalar('loss1_step', loss1, global_step)
            # writer_loss2.add_scalar('loss2_step', loss2, global_step)
            # writer_loss3.add_scalar('loss3_step', loss3, global_step)
            loss_record.append(loss.item())
            # loss1_record.append(loss1.item())
            # loss2_record.append(loss2.item())
            # loss3_record.append(loss3.item())
            if global_step % args.print_freq == 0 or global_step == 1:
                for name, param in model.named_parameters():
                    writer_loss.add_histogram(name,
                                              param.clone().cpu().data.numpy(),
                                              global_step,
                                              bins='doane')
                writer_loss.add_graph(model, [image, depth])
                grid_image1 = make_grid(image[:3].clone().cpu().data,
                                        3,
                                        normalize=True)
                writer_loss.add_image('image', grid_image1, global_step)
                grid_image2 = make_grid(depth[:3].clone().cpu().data,
                                        3,
                                        normalize=True)
                writer_loss.add_image('depth', grid_image2, global_step)
                grid_image3 = make_grid(utils.color_label(
                    torch.max(output[:3], 1)[1]),
                                        3,
                                        normalize=False,
                                        range=(0, 255))
                writer_loss.add_image('Predicted label', grid_image3,
                                      global_step)
                grid_image4 = make_grid(utils.color_label(label[:3]),
                                        3,
                                        normalize=False,
                                        range=(0, 255))
                writer_loss.add_image('Groundtruth label', grid_image4,
                                      global_step)

        tq.close()
        loss_train_mean = np.mean(loss_record)
        with open(log_file, 'a') as f:
            f.write(str(epoch) + '\t' + str(loss_train_mean))
        # loss1_train_mean = np.mean(loss1_record)
        # loss2_train_mean = np.mean(loss2_record)
        # loss3_train_mean = np.mean(loss3_record)
        writer_loss.add_scalar('epoch/loss_epoch_train',
                               float(loss_train_mean), epoch)
        # writer_loss1.add_scalar('epoch/sub_loss_epoch_train', float(loss1_train_mean), epoch)
        # writer_loss2.add_scalar('epoch/sub_loss_epoch_train', float(loss2_train_mean), epoch)
        # writer_loss3.add_scalar('epoch/sub_loss_epoch_train', float(loss3_train_mean), epoch)
        print('loss for train : %f' % loss_train_mean)
        print('----validation starting----')
        # tq_val = tqdm(total=len(val_loader) * args.batch_size)
        # tq_val.set_description('epoch %d' % epoch)
        model.eval()

        val_total_time = 0
        with torch.no_grad():
            sys.stdout.flush()
            tbar = tqdm(val_loader)
            acc_meter = AverageMeter()
            intersection_meter = AverageMeter()
            union_meter = AverageMeter()
            a_meter = AverageMeter()
            b_meter = AverageMeter()
            for batch_idx, sample in enumerate(tbar):

                # origin_image = sample['origin_image'].numpy()
                # origin_depth = sample['origin_depth'].numpy()
                image_val = sample['image'].to(device)
                depth_val = sample['depth'].to(device)
                label_val = sample['label'].numpy()

                with torch.no_grad():
                    start = time.time()
                    pred = model(image_val, depth_val)
                    end = time.time()
                    duration = end - start
                    val_total_time += duration
                # tq_val.set_postfix(fps ='%.4f' % (args.batch_size / (end - start)))
                print_str = 'Test step [{}/{}].'.format(
                    batch_idx + 1, len(val_loader))
                tbar.set_description(print_str)

                output_val = torch.max(pred, 1)[1]
                output_val = output_val.squeeze(0).cpu().numpy()

                acc, pix = accuracy(output_val, label_val)
                intersection, union = intersectionAndUnion(
                    output_val, label_val, args.num_class)
                acc_meter.update(acc, pix)
                a_m, b_m = macc(output_val, label_val, args.num_class)
                intersection_meter.update(intersection)
                union_meter.update(union)
                a_meter.update(a_m)
                b_meter.update(b_m)
        fps = len(val_loader) / val_total_time
        print('fps = %.4f' % fps)
        tbar.close()
        mAcc = (a_meter.average() / (b_meter.average() + 1e-10))
        with open(log_file, 'a') as f:
            f.write('                    ' + str(mAcc.mean()) + '\n')
        iou = intersection_meter.sum / (union_meter.sum + 1e-10)
        writer_acc.add_scalar('epoch/Acc_epoch_train', mAcc.mean(), epoch)
        print('----validation finished----')
        model.train()
        # save the checkpoint whenever validation mIoU improves (skipping the very first epoch)
        if epoch != args.start_epoch:
            if iou.mean() >= max_miou_val:
                print('mIoU:', iou.mean())
                if not os.path.isdir(args.ckpt_dir):
                    os.mkdir(args.ckpt_dir)
                save_ckpt(args.ckpt_dir, model, optimizer, global_step, epoch,
                          local_count, num_train)
                max_miou_val = iou.mean()
                # max_macc_val = mAcc.mean()
            else:
                loss_count += 1
        torch.cuda.empty_cache()
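The validation loop above relies on an `AverageMeter` accumulator and on `accuracy` / `intersectionAndUnion` helpers whose definitions are not part of the snippet (`macc` would follow the same pattern for per-class accuracy). The following is a minimal sketch of what such helpers typically look like; the bodies are assumptions reconstructed from the call sites (`update(val, weight)`, the `.sum` attribute, `average()`), not the project's actual code:

import numpy as np


class AverageMeter(object):
    """Accumulates a (possibly vector-valued) quantity together with its total weight."""

    def __init__(self):
        self.sum = 0
        self.weight = 0

    def update(self, val, weight=1):
        self.sum = self.sum + np.asarray(val, dtype=np.float64) * weight
        self.weight += weight

    def average(self):
        return self.sum / (self.weight + 1e-10)


def accuracy(pred, label):
    """Pixel accuracy over labelled pixels; returns (accuracy, number of valid pixels)."""
    valid = label >= 0
    correct = ((pred == label) & valid).sum()
    return correct / (valid.sum() + 1e-10), valid.sum()


def intersectionAndUnion(pred, label, num_class):
    """Per-class intersection and union counts used to compute mIoU."""
    pred = pred.copy()
    pred[label < 0] = -1  # drop unlabelled pixels from the histograms
    inter = pred[pred == label]
    area_inter, _ = np.histogram(inter, bins=num_class, range=(0, num_class))
    area_pred, _ = np.histogram(pred, bins=num_class, range=(0, num_class))
    area_label, _ = np.histogram(label, bins=num_class, range=(0, num_class))
    return area_inter, area_pred + area_label - area_inter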
コード例 #9
0
def main():
    # Set logger to record information.
    logger = Logger(cfg)
    logger.log_info(cfg)
    metrics_logger = Metrics()
    utils.pack_code(cfg, logger=logger)

    # Build model.
    model = model_builder.build_model(cfg=cfg, logger=logger)

    # Read checkpoint.
    ckpt = torch.load(cfg.MODEL.PATH2CKPT) if cfg.GENERAL.RESUME else {}

    if cfg.GENERAL.RESUME:
        model.load_state_dict(ckpt["model"])
    resume_epoch = ckpt["epoch"] if cfg.GENERAL.RESUME else 0
    optimizer = ckpt[
        "optimizer"] if cfg.GENERAL.RESUME else optimizer_helper.build_optimizer(
            cfg=cfg, model=model)
    # lr_scheduler = ckpt["lr_scheduler"] if cfg.GENERAL.RESUME else lr_scheduler_helper.build_scheduler(cfg=cfg, optimizer=optimizer)
    lr_scheduler = lr_scheduler_helper.build_scheduler(cfg=cfg,
                                                       optimizer=optimizer)
    lr_scheduler.sychronize(resume_epoch)
    loss_fn = ckpt[
        "loss_fn"] if cfg.GENERAL.RESUME else loss_fn_helper.build_loss_fn(
            cfg=cfg)

    # Set device.
    model, device = utils.set_device(model, cfg.GENERAL.GPU)

    # Prepare dataset.
    if cfg.GENERAL.TRAIN:
        try:
            train_data_loader = data_loader.build_data_loader(
                cfg, cfg.DATA.DATASET, "train")
        except:
            logger.log_info("Cannot build train dataset.")
    if cfg.GENERAL.VALID:
        try:
            valid_data_loader = data_loader.build_data_loader(
                cfg, cfg.DATA.DATASET, "valid")
        except:
            logger.log_info("Cannot build valid dataset.")
    if cfg.GENERAL.TEST:
        try:
            test_data_loader = data_loader.build_data_loader(
                cfg, cfg.DATA.DATASET, "test")
        except:
            logger.log_info("Cannot build test dataset.")

    # Train, evaluate model and save checkpoint.
    for epoch in range(cfg.TRAIN.MAX_EPOCH):
        if resume_epoch >= epoch:
            continue

        try:
            train_one_epoch(
                epoch=epoch,
                cfg=cfg,
                model=model,
                data_loader=train_data_loader,
                device=device,
                loss_fn=loss_fn,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,
                metrics_logger=metrics_logger,
                logger=logger,
            )
        except:
            logger.log_info("Failed to train model.")

        optimizer.zero_grad()
        with torch.no_grad():
            utils.save_ckpt(
                path2file=os.path.join(
                    cfg.MODEL.CKPT_DIR,
                    cfg.GENERAL.ID + "_" + str(epoch).zfill(3) + ".pth"),
                logger=logger,
                model=model.state_dict(),
                epoch=epoch,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,  # NOTE Need attribdict>=0.0.5
                loss_fn=loss_fn,
                metrics=metrics_logger,
            )
        try:
            evaluate(
                epoch=epoch,
                cfg=cfg,
                model=model,
                data_loader=valid_data_loader,
                device=device,
                loss_fn=loss_fn,
                metrics_logger=metrics_logger,
                phase="valid",
                logger=logger,
                save=cfg.SAVE.SAVE,
            )
        except:
            logger.log_info("Failed to evaluate model.")

        with torch.no_grad():
            utils.save_ckpt(
                path2file=os.path.join(
                    cfg.MODEL.CKPT_DIR,
                    cfg.GENERAL.ID + "_" + str(epoch).zfill(3) + ".pth"),
                logger=logger,
                model=model.state_dict(),
                epoch=epoch,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,  # NOTE Need attribdict>=0.0.5
                loss_fn=loss_fn,
                metrics=metrics_logger,
            )

    # If the test set has target images, evaluate and save them; otherwise just try to generate output images.
    if cfg.DATA.DATASET == "DualPixelNTIRE2021":
        try:
            generate(
                cfg=cfg,
                model=model,
                data_loader=valid_data_loader,
                device=device,
                phase="valid",
                logger=logger,
            )
        except:
            logger.log_info(
                "Failed to generate output images of valid set of NTIRE2021.")
    try:
        evaluate(
            epoch=epoch,
            cfg=cfg,
            model=model,
            data_loader=test_data_loader,
            device=device,
            loss_fn=loss_fn,
            metrics_logger=metrics_logger,
            phase="test",
            logger=logger,
            save=True,
        )
    except:
        logger.log_info("Failed to test model, try to generate images.")
        try:
            generate(
                cfg=cfg,
                model=model,
                data_loader=test_data_loader,
                device=device,
                phase="test",
                logger=logger,
            )
        except:
            logger.log_info("Cannot generate output images of test set.")
    return None
コード例 #10
0
    ##############
    # add tensorboard log for this epoch
    writer.add_scalar('train/total_avgloss_epoch', loss_meter.avg, epoch + 1)
    # print the log for this epoch
    log = 'Ep{}, {:.2f}s, loss {:.4f}'.format(epoch + 1,
                                              time.time() - ep_st,
                                              loss_meter.avg)
    print(log)

    # Average epoch time updates
    time_meter.update(time.time() - ep_st)
    # save model weights for every "epochs_per_save" epochs
    if (epoch + 1) % cfg.epochs_per_save == 0 or epoch + 1 == cfg.total_epochs:
        ckpt_file = os.path.join(cfg.exp_dir, 'model',
                                 'ckpt_epoch%d.pth' % (epoch + 1))
        save_ckpt(modules_optims, epoch + 1, 0, ckpt_file)

    ##########################
    # test on val set #
    ##########################
    if (epoch + 1) % cfg.epochs_per_val == 0 or (epoch +
                                                 1) == cfg.total_epochs:
        print('test on valset')
        # set the model to "eval" for testing
        may_set_mode(modules_optims, 'eval')
        testloss_meter = AverageMeter()
        pred_list = []  # predictions for the whole validation set
        target_list = []  # targets for the whole validation set
        # iterate over all batches in validation set
コード例 #11
0
def train_multitask(model, dataloaders):
    train_dataloader, dev_dataloader, test_dataloader = dataloaders
    if constant.optim == 'Adam':
        opt = torch.optim.Adam(model.parameters(), lr=constant.lr)
    elif constant.optim == 'SGD':
        opt = torch.optim.SGD(model.parameters(), lr=constant.lr)
    else:
        print("Optim is not defined")
        exit(1)

    start_epoch = 1
    if constant.restore:
        model, opt, start_epoch = load_ckpt(model, opt, constant.restore_path)

    if constant.USE_CUDA:
        model.cuda()
        if constant.embeddings_cpu:
            model.encoder.embedding.cpu()
                
    best_gen = 10000
    best_emo = 0
    best_path = ''
    patience = 3
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, 'min', factor=0.5, patience=0, min_lr=1e-6)

    try:
        for e in range(start_epoch, constant.epochs):
            model.train()
            gen_loss_log = []
            emo_loss_log = []
            cyc_loss_log = []
            ppl_log = []
            emo_f1_log = []
            cyc_f1_log = []

            if constant.grid_search:
                pbar = enumerate(train_dataloader)
            else:
                pbar = tqdm(enumerate(train_dataloader),total=len(train_dataloader))
            
            for b, (dialogs, lens, targets, _, _, emotions, sentiments, _, _) in pbar:
                if constant.use_sentiment:
                    emotions = sentiments
                if len(train_dataloader) % (b+1) == 10:
                    torch.cuda.empty_cache()
                opt.zero_grad()
                try:
                    # batch_size, max_target_len = targets.shape
                    emo_logits, _ = model(dialogs, lens, targets, emotions=emotions)

                    if emo_logits is not None:
                        emo_loss_log.append(model.loss['emo'].item())
                        if constant.use_emotion:
                            preds = torch.argmax(emo_logits, dim=1)
                        elif constant.use_sentiment:
                            preds = torch.sigmoid(emo_logits.squeeze()) > 0.5
                        emo_f1 = f1_score(emotions.cpu().numpy(), preds.detach().cpu().numpy(), average='weighted')
                        emo_f1_log.append(emo_f1)
                    else:
                        emo_loss_log = 100
                        emo_f1_log = 0

                    model.backward()
                    opt.step()

                    ## logging
                    gen_loss_log.append(model.loss['gen'].item())
                    ppl_log.append(math.exp(gen_loss_log[-1]))
                    if not constant.grid_search:
                        pbar.set_description("(Epoch {}) L_G:{:.4f} PPL:{:.1f} L_E:{:.4f} F1_E:{:.4f}".format(
                            e, np.mean(gen_loss_log), np.mean(ppl_log), np.mean(emo_loss_log), np.mean(emo_f1_log)))
                except RuntimeError as err:
                    if 'out of memory' in str(err):
                        print('| WARNING: ran out of memory, skipping batch')
                        torch.cuda.empty_cache()
                    else:
                        raise err
            ## LOG
            (gen_loss, ppl), (emo_f1) = eval_multitask(model, dev_dataloader, bleu=False)
            
            print("(Epoch {}) DEV GEN LOSS:{:.4f} DEV PPL:{:.1f} DEV EMO F1:{:.4f}".format(e, gen_loss, ppl, emo_f1))

            scheduler.step(gen_loss)
            if gen_loss < best_gen:
                best_gen = gen_loss
                best_emo = emo_f1
                # save best model
                path = 'trained/data-{}.task-multiseq.lr-{}.emb-{}.D-{}.H-{}.attn-{}.bi-{}.parse-{}.gen_loss-{}.emo_f1-{}' # lr.embedding.D.H.attn.bi.parse.metric
                path = path.format(constant.data, constant.lr, constant.embedding, constant.D, constant.H, constant.attn, constant.bi, constant.parse, best_gen, best_emo)
                if constant.use_sentiment:
                    path += '.sentiment'
                if constant.grid_search:
                    path += '.grid'
                best_path = save_model(model, 'loss', best_gen, path)
                patience = 3
            else:
                patience -= 1
            if patience == 0: break
            if best_gen == 0.0: break

    except KeyboardInterrupt:
        if not constant.grid_search:
            print("KEYBOARD INTERRUPT: Save CKPT and Eval")
            save = True if input('Save ckpt? (y/n)\t') in ['y', 'Y', 'yes', 'Yes'] else False
            if save:
                save_path = save_ckpt(model, opt, e)
                print("Saved CKPT path: ", save_path)
            # ask if eval
            do_eval = True if input('Proceed with eval? (y/n)\t') in ['y', 'Y', 'yes', 'Yes'] else False
            if do_eval:
                (dev_loss, dev_ppl, dev_bleu, dev_bleus), (emo_f1) = eval_multitask(model, dev_dataloader, bleu=True, beam=constant.beam)
                print("DEV LOSS: {:.4f}, DEV PPL: {:.1f}, DEV BLEU: {:.4f}".format(dev_loss, dev_ppl, dev_bleu))
                print("BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}".format(dev_bleus[0], dev_bleus[1], dev_bleus[2], dev_bleus[3]))
                print("DEV EMO F1: {:.4f}".format(emo_f1))
        exit(1)


    # load and report best model on test
    torch.cuda.empty_cache()
    model = load_model(model, best_path)
    if constant.USE_CUDA:
        model.cuda()

    (dev_loss, dev_ppl, dev_bleu, dev_bleus), (dev_emo_f1) = eval_multitask(model, dev_dataloader, bleu=True, beam=constant.beam)
    (test_loss, test_ppl, test_bleu, test_bleus), (test_emo_f1) = eval_multitask(model, test_dataloader, bleu=True, beam=constant.beam)

    print("BEST DEV LOSS: {:.4f}, DEV PPL: {:.1f}, DEV BLEU: {:.4f}".format(dev_loss, dev_ppl, dev_bleu))
    print("BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}".format(dev_bleus[0], dev_bleus[1], dev_bleus[2], dev_bleus[3]))
    print("DEV EMO F1: {:.4f}".format(dev_emo_f1))

    print("BEST TEST LOSS: {:.4f}, TEST PPL: {:.1f}, TEST BLEU: {:.4f}".format(test_loss, test_ppl, test_bleu))
    print("BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}".format(test_bleus[0], test_bleus[1], test_bleus[2], test_bleus[3]))
    print("TEST EMO F1: {:.4f}".format(test_emo_f1))
コード例 #12
0
File: main.py Project: ClaraBing/PyTorch-BiGAN
                        type=str,
                        default='',
                        help="Suffix for save_path")
    parser.add_argument('--pretrained-path', type=str, default='')
    #parsing arguments.
    args = parser.parse_args()
    args.save_path = 'BiGAN_{}_lr{}_wd1e-6_bt{}_dim{}_k{}_W{}_{}{}epoch{}{}{}{}.pt'.format(
        args.data, args.lr_adam, args.batch_size, args.latent_dim,
        args.first_filter_size, 1 if args.wasserstein else 0,
        'l2{}_'.format(args.l2_loss_weight) if args.use_l2_loss else '',
        'zRelu_' if args.use_relu_z else '', args.num_epochs,
        '_normed' if args.normalize_data else '',
        '_freezeGD' if args.freeze_GD else '',
        '_' + args.save_token if args.save_token else '')
    if USE_WANDB:
        wandb.init(project='visualize', name=args.save_path, config=args)
    args.save_path = os.path.join('ckpts', args.save_path)

    #check if cuda is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if args.data == 'cifar':
        data = get_cifar10(args)
    elif args.data == 'mnist':
        data = get_mnist(args)

    bigan = TrainerBiGAN(args, data, device)
    bigan.train()
    print('Finished training.')
    save_ckpt(bigan, args.save_path)
コード例 #13
0
def train(cfg):
    '''
    This is the main loop for training
    Loads the dataset, model, and other things
    '''
    print(json.dumps(cfg, sort_keys=True, indent=4))

    use_cuda = cfg['use-cuda']

    _, _, train_dl, val_dl = utils.get_data_loaders(cfg)

    model = utils.get_model(cfg)
    if use_cuda:
        model = model.cuda()
    model = utils.init_weights(model, cfg)

    # Get pretrained models, optimizers and loss functions
    optim = utils.get_optimizers(model, cfg)
    model, optim, metadata = utils.load_ckpt(model, optim, cfg)
    loss_fn = utils.get_losses(cfg)

    # Set up random seeds
    seed = np.random.randint(2**32)
    ckpt = 0
    if metadata is not None:
        seed = metadata['seed']
        ckpt = metadata['ckpt']

    # Get schedulers after getting checkpoints
    scheduler = utils.get_schedulers(optim, cfg, ckpt)
    # Print optimizer state
    print(optim)

    # Get loss file handle to dump logs to
    if not os.path.exists(cfg['save-path']):
        os.makedirs(cfg['save-path'])
    lossesfile = open(os.path.join(cfg['save-path'], 'losses.txt'), 'a+')

    # Random seed according to what the saved model is
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Run training loop
    num_epochs = cfg['train']['num-epochs']
    for epoch in range(num_epochs):
        # Run the main training loop
        model.train()
        for data in train_dl:
            # zero out the grads
            optim.zero_grad()

            # Change to required device
            for key, value in data.items():
                data[key] = Variable(value)
                if use_cuda:
                    data[key] = data[key].cuda()

            # Get all outputs
            outputs = model(data)
            loss_val = loss_fn(outputs, data, cfg)

            # print it
            print('Epoch: {}, step: {}, loss: {}'.format(
                epoch, ckpt,
                loss_val.data.cpu().numpy()))

            # Log into the file after some epochs
            if ckpt % cfg['train']['step-log'] == 0:
                lossesfile.write('Epoch: {}, step: {}, loss: {}\n'.format(
                    epoch, ckpt,
                    loss_val.data.cpu().numpy()))

            # Backward
            loss_val.backward()
            optim.step()

            # Update schedulers
            scheduler.step()

            # Peek into the validation set
            ckpt += 1
            if ckpt % cfg['peek-validation'] == 0:
                model.eval()
                with torch.no_grad():
                    for val_data in val_dl:
                        # Change to required device
                        for key, value in val_data.items():
                            val_data[key] = Variable(value)
                            if use_cuda:
                                val_data[key] = val_data[key].cuda()

                        # Get all outputs
                        outputs = model(val_data)
                        loss_val = loss_fn(outputs, val_data, cfg)

                        print('Validation loss: {}'.format(
                            loss_val.data.cpu().numpy()))

                        lossesfile.write('Validation loss: {}\n'.format(\
                            loss_val.data.cpu().numpy()))
                        utils.save_images(val_data, outputs, cfg, ckpt)
                        break
                model.train()
            # Save checkpoint
            utils.save_ckpt((model, optim), cfg, ckpt, seed)

    lossesfile.close()
コード例 #14
0
def train():
    # log data for display in TensorBoard
    writer = SummaryWriter(args.summary_dir)

    # prepare the dataset
    train_data = data_eval.ReadData(transform=transforms.Compose([
        data_eval.scaleNorm(),
        data_eval.RandomScale((1.0, 1.4)),
        data_eval.RandomHSV((0.9, 1.1), (0.9, 1.1), (25, 25)),
        data_eval.RandomCrop(image_h, image_w),
        data_eval.RandomFlip(),
        data_eval.ToTensor(),
        data_eval.Normalize()
    ]),
                                    data_dir=args.data_dir)
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=False,
                              drop_last=True)

    num_train = len(train_data)
    # data = iter(train_loader).next()
    # print('data:', data['image'].shape)

    # build model
    if args.last_ckpt:
        model = MultiTaskCNN(38,
                             depth_channel=1,
                             pretrained=False,
                             arch='resnet18')
    else:
        model = MultiTaskCNN(38,
                             depth_channel=1,
                             pretrained=True,
                             arch='resnet18')
    model = model.to(device)

    # build optimizer
    if args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(), args.lr)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.lr)
    else:
        print('unsupported optimizer\n')
        return None
    global_step = 0
    # if a trained checkpoint exists, restore global_step and start_epoch
    if args.last_ckpt:
        global_step, args.start_epoch = load_ckpt(model, optimizer,
                                                  args.last_ckpt, device)

    # # check which GPUs are being used
    # if torch.cuda.device_count() > 1:
    #     print("Let's use", torch.cuda.device_count(), "GPUs!")
    #     # nn.DataParallel(module, device_ids=None, output_device=None, dim=0): run computation on multiple GPUs
    #     model = nn.DataParallel(model)

    model.train()
    # cal_param(model, data)
    loss_func = nn.CrossEntropyLoss(weight=weight.float())

    for epoch in range(int(args.start_epoch), args.epochs):
        tq = tqdm(total=len(train_loader) * args.batch_size)
        lr = poly_lr_scheduler(optimizer,
                               args.lr,
                               iter=epoch,
                               max_iter=args.epochs)
        tq.set_description('epoch %d, lr %f' % (epoch, lr))
        loss_record = []
        local_count = 0
        # print('1')
        for batch_idx, data in enumerate(train_loader):
            # print(batch_idx)
            image = data['image'].to(device)
            depth = data['depth'].to(device)
            label = data['label'].long().to(device)
            # print('label', label.shape)
            output, output_sup1, output_sup2 = model(image, depth)
            loss1 = loss_func(output, label)
            loss2 = loss_func(output_sup1, label)
            loss3 = loss_func(output_sup2, label)
            loss = loss1 + loss2 + loss3
            tq.update(args.batch_size)
            tq.set_postfix(loss='%.6f' % loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1
            local_count += image.data.shape[0]
            writer.add_scalar('loss_step', loss, global_step)
            loss_record.append(loss.item())

            if global_step % args.print_freq == 0 or global_step == 1:
                for name, param in model.named_parameters():
                    writer.add_histogram(name,
                                         param.clone().cpu().data.numpy(),
                                         global_step,
                                         bins='doane')
                writer.add_graph(model, [image, depth])
                grid_image1 = make_grid(image[:3].clone().cpu().data,
                                        3,
                                        normalize=True)
                writer.add_image('image', grid_image1, global_step)
                grid_image2 = make_grid(depth[:3].clone().cpu().data,
                                        3,
                                        normalize=True)
                writer.add_image('depth', grid_image2, global_step)
                grid_image3 = make_grid(utils.color_label(
                    torch.max(output[:3], 1)[1]),
                                        3,
                                        normalize=False,
                                        range=(0, 255))
                writer.add_image('Predicted label', grid_image3, global_step)
                grid_image4 = make_grid(utils.color_label(label[:3]),
                                        3,
                                        normalize=False,
                                        range=(0, 255))
                writer.add_image('Groundtruth label', grid_image4, global_step)

        tq.close()
        loss_train_mean = np.mean(loss_record)
        writer.add_scalar('epoch/loss_epoch_train', float(loss_train_mean),
                          epoch)
        print('loss for train : %f' % loss_train_mean)
        # save the weights once every save_epoch_freq epochs
        if epoch % args.save_epoch_freq == 0 and epoch != args.start_epoch:
            if not os.path.isdir(args.ckpt_dir):
                os.mkdir(args.ckpt_dir)
            save_ckpt(args.ckpt_dir, model, optimizer, global_step, epoch,
                      local_count, num_train)
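Example #14 above (and example #19 further below, which uses the same signatures) assumes three utilities that are not shown: `poly_lr_scheduler`, `load_ckpt`, and a `save_ckpt(ckpt_dir, model, optimizer, global_step, epoch, local_count, num_train)` that tags checkpoints with the training progress. A minimal sketch consistent with those call sites follows; the file-naming scheme and dictionary keys are illustrative assumptions, not the repositories' actual helpers:

import os
import torch


def poly_lr_scheduler(optimizer, init_lr, iter, max_iter, power=0.9):
    """Polynomial decay: lr = init_lr * (1 - iter / max_iter) ** power."""
    lr = init_lr * (1 - iter / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def save_ckpt(ckpt_dir, model, optimizer, global_step, epoch, local_count, num_train):
    """Save model/optimizer state, tagged with the position inside the current epoch."""
    filename = os.path.join(
        ckpt_dir, 'ckpt_epoch_{:0.2f}.pth'.format(epoch + local_count / num_train))
    torch.save({
        'global_step': global_step,
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, filename)
    print('Checkpoint saved to {}'.format(filename))


def load_ckpt(model, optimizer, ckpt_path, device):
    """Restore weights and optimizer state; return the counters needed to resume."""
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(ckpt['optimizer'])
    return ckpt['global_step'], ckpt['epoch']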
コード例 #15
0
def train_sentiment(model, dataloaders):
    """ 
    Training loop
    Inputs:
        model: the model to be trained
        dataloader: data loader
    Output:
        best_dev: best f1 score on dev data
        best_test: best f1 score on test data
    """
    train_dataloader, dev_dataloader, test_dataloader = dataloaders
    if (constant.USE_CUDA): model.cuda()
    criterion = nn.BCEWithLogitsLoss()

    if constant.use_bert:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        opt = BertAdam(optimizer_grouped_parameters,
                       lr=constant.lr,
                       warmup=0.01,
                       t_total=int(len(train_dataloader) * 5))
    else:
        opt = torch.optim.Adam(model.parameters(), lr=constant.lr)

    best_dev = 0
    best_test = 0
    patience = 3

    try:
        for e in range(constant.epochs):
            model.train()
            loss_log = []
            f1_log = []

            if constant.grid_search:
                pbar = enumerate(train_dataloader)
            else:
                pbar = tqdm(enumerate(train_dataloader),
                            total=len(train_dataloader))

            for _, batch in pbar:
                if constant.use_bert:
                    input_ids, input_masks, segment_ids, sentiments = batch
                    logits = model(
                        (input_ids, segment_ids, input_masks)).squeeze()
                else:
                    sentences, lens, sentiments = batch
                    logits = model(sentences, lens).squeeze()

                if len(logits.shape) == 0:
                    logits = logits.unsqueeze(0)
                loss = criterion(logits, sentiments)
                loss.backward()
                opt.step()
                opt.zero_grad()

                ## logging
                loss_log.append(loss.item())
                preds = F.sigmoid(logits) > 0.5
                # preds = torch.argmax(logits, dim=1)
                f1 = f1_score(sentiments.cpu().numpy(),
                              preds.detach().cpu().numpy(),
                              average='weighted')
                f1_log.append(f1)
                if not constant.grid_search:
                    pbar.set_description(
                        "(Epoch {}) TRAIN F1:{:.4f} TRAIN LOSS:{:.4f}".format(
                            e + 1, np.mean(f1_log), np.mean(loss_log)))

            ## LOG
            f1 = eval_sentiment(model, dev_dataloader)
            testF1 = eval_sentiment(model, test_dataloader)
            print("(Epoch {}) DEV F1: {:.4f} TEST F1: {:.4f}".format(
                e + 1, f1, testF1))
            print("(Epoch {}) BEST DEV F1: {:.4f} BEST TEST F1: {:.4f}".format(
                e + 1, best_dev, best_test))
            if (f1 > best_dev):
                best_dev = f1
                best_test = testF1
                patience = 3
                path = 'trained/data-{}.task-sentiment.f1-{}'
                save_model(model, 'loss', best_dev,
                           path.format(constant.data, best_dev))
            else:
                patience -= 1
            if (patience == 0): break
            if (best_dev == 1.0): break

    except KeyboardInterrupt:
        if not constant.grid_search:
            print("KEYBOARD INTERRUPT: Save CKPT and Eval")
            save = True if input('Save ckpt? (y/n)\t') in [
                'y', 'Y', 'yes', 'Yes'
            ] else False
            if save:
                save_path = save_ckpt(model, opt, e)
                print("Saved CKPT path: ", save_path)
            print("BEST SCORES - DEV F1: {:.4f}, TEST F1: {:.4f}".format(
                best_dev, best_test))
        exit(1)

    print("BEST SCORES - DEV F1: {:.4f}, TEST F1: {:.4f}".format(
        best_dev, best_test))
コード例 #16
0
def main(args):
    print('==> Using settings {}'.format(args))
    device = torch.device("cuda")

    print('==> Loading dataset...')
    data_dict = data_preparation(args)

    print("==> Creating PoseNet model...")
    model_pos = model_pos_preparation(args, data_dict['dataset'], device)
    print("==> Prepare optimizer...")
    criterion = nn.MSELoss(reduction='mean').to(device)
    optimizer = torch.optim.Adam(model_pos.parameters(), lr=args.lr)

    ckpt_dir_path = path.join(
        args.checkpoint, args.posenet_name, args.keypoints,
        datetime.datetime.now().strftime('%m%d%H%M%S') + '_' + args.note)
    os.makedirs(ckpt_dir_path, exist_ok=True)
    print('==> Making checkpoint dir: {}'.format(ckpt_dir_path))

    logger = Logger(os.path.join(ckpt_dir_path, 'log.txt'), args)
    logger.set_names([
        'epoch', 'lr', 'loss_train', 'error_h36m_p1', 'error_h36m_p2',
        'error_3dhp_p1', 'error_3dhp_p2'
    ])

    #################################################
    # ########## start training here
    #################################################
    start_epoch = 0
    error_best = None
    glob_step = 0
    lr_now = args.lr

    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr_now))

        # Train for one epoch
        epoch_loss, lr_now, glob_step = train(data_dict['train_loader'],
                                              model_pos,
                                              criterion,
                                              optimizer,
                                              device,
                                              args.lr,
                                              lr_now,
                                              glob_step,
                                              args.lr_decay,
                                              args.lr_gamma,
                                              max_norm=args.max_norm)

        # Evaluate
        error_h36m_p1, error_h36m_p2 = evaluate(data_dict['H36M_test'],
                                                model_pos, device)
        error_3dhp_p1, error_3dhp_p2 = evaluate(data_dict['3DHP_test'],
                                                model_pos,
                                                device,
                                                flipaug='_flip')

        # Update log file
        logger.append([
            epoch + 1, lr_now, epoch_loss, error_h36m_p1, error_h36m_p2,
            error_3dhp_p1, error_3dhp_p2
        ])

        # Update checkpoint
        if error_best is None or error_best > error_h36m_p1:
            error_best = error_h36m_p1
            save_ckpt(
                {
                    'state_dict': model_pos.state_dict(),
                    'epoch': epoch + 1
                },
                ckpt_dir_path,
                suffix='best')

        if (epoch + 1) % args.snapshot == 0:
            save_ckpt(
                {
                    'state_dict': model_pos.state_dict(),
                    'epoch': epoch + 1
                }, ckpt_dir_path)

    logger.close()
    logger.plot(['loss_train', 'error_h36m_p1'])
    savefig(path.join(ckpt_dir_path, 'log.eps'))
    return
コード例 #17
0
File: main_3dpw.py Project: zibozzb/LearnTrajDep
def main(opt):
    start_epoch = 0
    err_best = 10000
    lr_now = opt.lr
    is_cuda = torch.cuda.is_available()

    script_name = os.path.basename(__file__).split('.')[0]
    script_name = script_name + '_in{:d}_out{:d}_dctn_{:d}'.format(
        opt.input_n, opt.output_n, opt.dct_n)

    # create model
    print(">>> creating model")
    input_n = opt.input_n
    output_n = opt.output_n
    dct_n = opt.dct_n

    model = nnmodel.GCN(input_feature=dct_n,
                        hidden_feature=opt.linear_size,
                        p_dropout=opt.dropout,
                        num_stage=opt.num_stage,
                        node_n=69)

    if is_cuda:
        model.cuda()

    print(">>> total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    if opt.is_load:
        model_path_len = 'checkpoint/test/ckpt_main_last.pth.tar'
        print(">>> loading ckpt len from '{}'".format(model_path_len))
        if is_cuda:
            ckpt = torch.load(model_path_len)
        else:
            ckpt = torch.load(model_path_len, map_location='cpu')
        start_epoch = ckpt['epoch']
        err_best = ckpt['err']
        lr_now = ckpt['lr']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
        print(">>> ckpt len loaded (epoch: {} | err: {})".format(
            start_epoch, err_best))

    # data loading
    print(">>> loading data")
    train_dataset = Pose3dPW(path_to_data=opt.data_dir_3dpw,
                             input_n=input_n,
                             output_n=output_n,
                             dct_n=dct_n,
                             split=0)
    dim_used = train_dataset.dim_used
    test_dataset = Pose3dPW(path_to_data=opt.data_dir_3dpw,
                            input_n=input_n,
                            output_n=output_n,
                            dct_n=dct_n,
                            split=1)
    val_dataset = Pose3dPW(path_to_data=opt.data_dir_3dpw,
                           input_n=input_n,
                           output_n=output_n,
                           dct_n=dct_n,
                           split=2)

    # load datasets for training
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=opt.train_batch,
                              shuffle=True,
                              num_workers=opt.job,
                              pin_memory=True)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=opt.test_batch,
                             shuffle=False,
                             num_workers=opt.job,
                             pin_memory=True)
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=opt.test_batch,
                            shuffle=False,
                            num_workers=opt.job,
                            pin_memory=True)
    print(">>> data loaded !")
    print(">>> train data {}".format(train_dataset.__len__()))
    print(">>> test data {}".format(test_dataset.__len__()))
    print(">>> validation data {}".format(val_dataset.__len__()))

    for epoch in range(start_epoch, opt.epochs):

        if (epoch + 1) % opt.lr_decay == 0:
            lr_now = utils.lr_decay(optimizer, lr_now, opt.lr_gamma)
        print('==========================')
        print('>>> epoch: {} | lr: {:.5f}'.format(epoch + 1, lr_now))
        ret_log = np.array([epoch + 1])
        head = np.array(['epoch'])
        # per epoch
        lr_now, t_l, t_err = train(train_loader,
                                   model,
                                   optimizer,
                                   input_n=input_n,
                                   dct_n=dct_n,
                                   dim_used=dim_used,
                                   lr_now=lr_now,
                                   max_norm=opt.max_norm,
                                   is_cuda=is_cuda)
        ret_log = np.append(ret_log, [lr_now, t_l, t_err])
        head = np.append(head, ['lr', 't_l', 't_err'])

        v_err = val(val_loader,
                    model,
                    input_n=input_n,
                    dct_n=dct_n,
                    dim_used=dim_used,
                    is_cuda=is_cuda)

        ret_log = np.append(ret_log, v_err)
        head = np.append(head, ['v_err'])

        test_3d = test(test_loader,
                       model,
                       input_n=input_n,
                       output_n=output_n,
                       dct_n=dct_n,
                       dim_used=dim_used,
                       is_cuda=is_cuda)
        # ret_log = np.append(ret_log, test_l)
        ret_log = np.append(ret_log, test_3d)
        if output_n == 15:
            head = np.append(head,
                             ['1003d', '2003d', '3003d', '4003d', '5003d'])
        elif output_n == 30:
            head = np.append(head, [
                '1003d', '2003d', '3003d', '4003d', '5003d', '6003d', '7003d',
                '8003d', '9003d', '10003d'
            ])

        # update log file
        df = pd.DataFrame(np.expand_dims(ret_log, axis=0))
        if epoch == start_epoch:
            df.to_csv(opt.ckpt + '/' + script_name + '.csv',
                      header=head,
                      index=False)
        else:
            with open(opt.ckpt + '/' + script_name + '.csv', 'a') as f:
                df.to_csv(f, header=False, index=False)
        # save ckpt
        is_best = v_err < err_best
        err_best = min(v_err, err_best)
        file_name = [
            'ckpt_' + script_name + '_best.pth.tar',
            'ckpt_' + script_name + '_last.pth.tar'
        ]
        utils.save_ckpt(
            {
                'epoch': epoch + 1,
                'lr': lr_now,
                'err': test_3d[0],
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            },
            ckpt_path=opt.ckpt,
            is_best=is_best,
            file_name=file_name)
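Example #17 above (and example #18 below) relies on two small utilities: `utils.lr_decay`, which scales the learning rate by `lr_gamma` every `lr_decay` epochs, and `utils.save_ckpt(state, ckpt_path=..., is_best=..., file_name=[best, last])`, which overwrites the "last" checkpoint on every call and additionally refreshes the "best" copy. A minimal sketch of that assumed behaviour (for illustration only; the actual helpers live in the repositories' `utils` modules):

import os
import torch


def lr_decay(optimizer, lr_now, gamma):
    """Scale the current learning rate by gamma and write it back into the optimizer."""
    lr = lr_now * gamma
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def save_ckpt(state, ckpt_path, is_best=True,
              file_name=('ckpt_best.pth.tar', 'ckpt_last.pth.tar')):
    """Always write the 'last' checkpoint; also save it as 'best' when is_best is True."""
    torch.save(state, os.path.join(ckpt_path, file_name[1]))
    if is_best:
        torch.save(state, os.path.join(ckpt_path, file_name[0]))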
コード例 #18
0
def main(opt):
    start_epoch = 0
    err_best = 10000
    lr_now = opt.lr
    is_cuda = torch.cuda.is_available()

    print(">>> loading data")
    input_n = opt.input_n
    output_n = opt.output_n
    dct_n = opt.dct_n
    sample_rate = opt.sample_rate

    #####################################################
    # Load data
    #####################################################
    data = DATA(opt.dataset, opt.data_dir)
    out_of_distribution = data.get_dct_and_sequences(input_n, output_n,
                                                     sample_rate, dct_n,
                                                     opt.out_of_distribution)
    train_loader, val_loader, OoD_val_loader, test_loaders = data.get_dataloaders(
        opt.train_batch, opt.test_batch, opt.job)
    print(">>> data loaded !")
    print(">>> train data {}".format(data.train_dataset.__len__()))
    if opt.dataset == 'h3.6m':
        print(">>> validation data {}".format(data.val_dataset.__len__()))

    #####################################################
    # Define script name
    #####################################################
    script_name = os.path.basename(__file__).split('.')[0]
    script_name = script_name + "_{}_in{:d}_out{:d}_dctn{:d}_dropout_{}".format(
        str(opt.dataset), opt.input_n, opt.output_n, opt.dct_n, str(
            opt.dropout))
    if out_of_distribution:
        script_name = script_name + "_OoD_{}_".format(
            str(opt.out_of_distribution))
    if opt.variational:
        script_name = script_name + "_var_lambda_{}_nz_{}_lr_{}_n_layers_{}".format(
            str(opt.lambda_), str(opt.n_z), str(opt.lr),
            str(opt.num_decoder_stage))

    ##################################################################
    # Instantiate model, and methods used for training and validation
    ##################################################################
    print(">>> creating model")
    model = nnmodel.GCN(input_feature=dct_n,
                        hidden_feature=opt.linear_size,
                        p_dropout=opt.dropout,
                        num_stage=opt.num_stage,
                        node_n=data.node_n,
                        variational=opt.variational,
                        n_z=opt.n_z,
                        num_decoder_stage=opt.num_decoder_stage)
    methods = MODEL_METHODS(model, is_cuda)
    if opt.is_load:
        start_epoch, err_best, lr_now = methods.load_weights(opt.load_path)
    print(">>> total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))
    methods.optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    for epoch in range(start_epoch, opt.epochs):
        #####################################################################################################################################################
        # Training step
        #####################################################################################################################################################
        if (epoch + 1) % opt.lr_decay == 0:
            lr_now = utils.lr_decay(methods.optimizer, lr_now, opt.lr_gamma)
        print('==========================')
        print('>>> epoch: {} | lr: {:.5f}'.format(epoch + 1, lr_now))
        ret_log = np.array([epoch + 1])
        head = np.array(['epoch'])
        # per epoch
        lr_now, t_l, t_l_joint, t_l_vlb, t_l_latent, t_e, t_3d = methods.train(
            train_loader,
            dataset=opt.dataset,
            input_n=input_n,
            lr_now=lr_now,
            cartesian=data.cartesian,
            lambda_=opt.lambda_,
            max_norm=opt.max_norm,
            dim_used=data.train_dataset.dim_used,
            dct_n=dct_n)
        ret_log = np.append(
            ret_log, [lr_now, t_l, t_l_joint, t_l_vlb, t_l_latent, t_e, t_3d])
        head = np.append(
            head,
            ['lr', 't_l', 't_l_joint', 't_l_vlb', 't_l_latent', 't_e', 't_3d'])

        #####################################################################################################################################################
        # Evaluate on validation set; Keep track of best, either via val set, OoD val set (in the case of OoD), or train set in the case of the CMU dataset
        #####################################################################################################################################################
        if opt.dataset == 'h3.6m':
            v_e, v_3d = methods.val(val_loader,
                                    input_n=input_n,
                                    dim_used=data.train_dataset.dim_used,
                                    dct_n=dct_n)
            ret_log = np.append(ret_log, [v_e, v_3d])
            head = np.append(head, ['v_e', 'v_3d'])

            is_best, err_best = utils.check_is_best(v_e, err_best)
            if out_of_distribution:
                OoD_v_e, OoD_v_3d = methods.val(
                    OoD_val_loader,
                    input_n=input_n,
                    dim_used=data.train_dataset.dim_used,
                    dct_n=dct_n)
                ret_log = np.append(ret_log, [OoD_v_e, OoD_v_3d])
                head = np.append(head, ['OoD_v_e', 'OoD_v_3d'])
        else:
            is_best, err_best = utils.check_is_best(t_e, err_best)

        #####################################################
        # Evaluate on test set
        #####################################################
        test_3d_temp = np.array([])
        test_3d_head = np.array([])
        for act in data.acts_test:
            test_e, test_3d = methods.test(
                test_loaders[act],
                dataset=opt.dataset,
                input_n=input_n,
                output_n=output_n,
                cartesian=data.cartesian,
                dim_used=data.train_dataset.dim_used,
                dct_n=dct_n)
            ret_log = np.append(ret_log, test_e)
            test_3d_temp = np.append(test_3d_temp, test_3d)
            test_3d_head = np.append(
                test_3d_head,
                [act + '3d80', act + '3d160', act + '3d320', act + '3d400'])
            head = np.append(
                head, [act + '80', act + '160', act + '320', act + '400'])
            if output_n > 10:
                head = np.append(head, [act + '560', act + '1000'])
                test_3d_head = np.append(test_3d_head,
                                         [act + '3d560', act + '3d1000'])
        ret_log = np.append(ret_log, test_3d_temp)
        head = np.append(head, test_3d_head)

        #####################################################
        # Update log file and save checkpoint
        #####################################################
        df = pd.DataFrame(np.expand_dims(ret_log, axis=0))
        if epoch == start_epoch:
            df.to_csv(opt.ckpt + '/' + script_name + '.csv',
                      header=head,
                      index=False)
        else:
            with open(opt.ckpt + '/' + script_name + '.csv', 'a') as f:
                df.to_csv(f, header=False, index=False)
        file_name = [
            'ckpt_' + script_name + '_best.pth.tar',
            'ckpt_' + script_name + '_last.pth.tar'
        ]
        utils.save_ckpt(
            {
                'epoch': epoch + 1,
                'lr': lr_now,
                'err': test_e[0],
                'state_dict': model.state_dict(),
                'optimizer': methods.optimizer.state_dict()
            },
            ckpt_path=opt.ckpt,
            is_best=is_best,
            file_name=file_name)
コード例 #19
0
def train():
    train_data = ACNet_data.SUNRGBD(transform=transforms.Compose([
        ACNet_data.scaleNorm(),
        ACNet_data.RandomScale((1.0, 1.4)),
        ACNet_data.RandomHSV((0.9, 1.1), (0.9, 1.1), (25, 25)),
        ACNet_data.RandomCrop(image_h, image_w),
        ACNet_data.RandomFlip(),
        ACNet_data.ToTensor(),
        ACNet_data.Normalize()
    ]),
                                    phase_train=True,
                                    data_dir=args.data_dir)
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=False)

    num_train = len(train_data)

    if args.last_ckpt:
        model = ACNet_models_V1.ACNet(num_class=40, pretrained=False)
    else:
        model = ACNet_models_V1.ACNet(num_class=40, pretrained=True)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    # CEL_weighted = utils.CrossEntropyLoss2d()
    CEL_weighted = utils.FocalLoss2d(weight=nyuv2_frq, gamma=2)
    model.train()
    model.to(device)
    CEL_weighted.to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    global_step = 0

    if args.last_ckpt:
        global_step, args.start_epoch = load_ckpt(model, optimizer,
                                                  args.last_ckpt, device)

    #hxx for finetuing
    # lr_decay_lambda = lambda epoch: 0.2 * args.lr_decay_rate ** ((epoch - args.start_epoch) // args.lr_epoch_per_decay)
    lr_decay_lambda = lambda epoch: (1 - (epoch - args.start_epoch) /
                                     (args.epochs - args.start_epoch))**0.9
    scheduler = LambdaLR(optimizer, lr_lambda=lr_decay_lambda)

    writer = SummaryWriter(args.summary_dir)

    for epoch in range(int(args.start_epoch), args.epochs):
        # if (epoch - args.start_epoch) % args.lr_epoch_per_decay == 0:
        scheduler.step(epoch)
        local_count = 0
        last_count = 0
        end_time = time.time()
        if epoch % args.save_epoch_freq == 0 and epoch != args.start_epoch:
            save_ckpt(args.ckpt_dir, model, optimizer, global_step, epoch,
                      local_count, num_train)

        for batch_idx, sample in enumerate(train_loader):

            image = sample['image'].to(device)
            depth = sample['depth'].to(device)
            target_scales = [
                sample[s].to(device)
                for s in ['label', 'label2', 'label3', 'label4', 'label5']
            ]
            optimizer.zero_grad()
            pred_scales = model(image, depth, args.checkpoint)
            loss = CEL_weighted(pred_scales, target_scales)
            loss.backward()
            optimizer.step()
            local_count += image.data.shape[0]
            global_step += 1
            if global_step % args.print_freq == 0 or global_step == 1:

                time_inter = time.time() - end_time
                count_inter = local_count - last_count
                print_log(global_step, epoch, local_count, count_inter,
                          num_train, loss, time_inter)
                end_time = time.time()

                last_count = local_count

    save_ckpt(args.ckpt_dir, model, optimizer, global_step, args.epochs, 0,
              num_train)

    print("Training completed ")
コード例 #20
0
def train_rl(model, dataloaders):
    train_dataloader, dev_dataloader, test_dataloader = dataloaders

    clf_criterion = nn.BCEWithLogitsLoss()
    mle_criterion = nn.CrossEntropyLoss(ignore_index=constant.pad_idx)
    baseline_criterion = nn.MSELoss()

    if constant.optim == 'Adam':
        opt = torch.optim.Adam(model.parameters(), lr=constant.lr)
    elif constant.optim == 'SGD':
        opt = torch.optim.SGD(model.parameters(), lr=constant.lr)
    else:
        print("Optim is not defined")
        exit(1)

    start_epoch = 1
    if constant.restore:
        model, opt, start_epoch = load_ckpt(model, opt, constant.restore_path)

    if constant.USE_CUDA:
        model.cuda()

    best_dev = 0
    best_path = ''
    patience = 3
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,
                                                           'max',
                                                           factor=0.5,
                                                           patience=0,
                                                           min_lr=1e-6)
    tau = constant.tau
    tau_min = 0.2
    tau_dec = 0.2
    pretrain_curiosity = constant.lambda_aux
    if constant.pretrain_curiosity:
        pretrain_curiosity = 0.0

    try:
        for e in range(start_epoch, constant.epochs):
            model.train()
            reward_log = []
            ori_reward_log = []
            aux_reward_log = []  # for sentiment agreement / curiosity
            inv_loss_log = []  # for curiosity
            f1_log = []

            # pretrain curiosity only for first epoch
            if e > start_epoch:
                pretrain_curiosity = constant.lambda_aux

            # temperature annealing
            if constant.use_tau_anneal and e > start_epoch and constant.tau > tau_min:
                constant.tau -= tau_dec

            if constant.grid_search:
                pbar = enumerate(train_dataloader)
            else:
                pbar = tqdm(enumerate(train_dataloader),
                            total=len(train_dataloader))

            for b, (dialogs, lens, targets, _, _, sentiments, sentiments_b, _,
                    _) in pbar:
                if len(train_dataloader) % (b + 1) == 10:
                    torch.cuda.empty_cache()
                opt.zero_grad()
                try:
                    B, T = targets.shape

                    if constant.use_self_critical:
                        step_loss, dec_lens_var, R_g, R, greedy_sents, sampled_sents = model(
                            dialogs, lens, targets)
                        # (R_s - R_g) * step_loss
                        rl_loss = torch.mean(
                            torch.sum((R.detach() - R_g.detach()) * step_loss,
                                      dim=1) / dec_lens_var.float())
                    elif constant.use_arl:
                        step_loss, dec_lens_var, rs, R, arl, sampled_sents = model(
                            dialogs, lens, targets)
                        rs = rs.transpose(0, 1).contiguous()

                        rl_loss = (R.detach() - rs.detach()) * step_loss
                        rl_loss = torch.mean(
                            torch.sum(rl_loss * arl, dim=1) /
                            dec_lens_var.float())
                    else:
                        # probs: (B, T, V), xs: (B, T), R: (B, 1), rs: (B, T)
                        if constant.use_sentiment and constant.aux_reward_model != '':
                            step_loss, dec_lens_var, rs, R_l, R_s, sampled_sents, clf_logits = model(
                                dialogs, lens, targets, sentiments=sentiments)
                            R = constant.lambda_aux * R_l + R_s
                            clf_loss = clf_criterion(clf_logits, sentiments_b)
                            preds = torch.sigmoid(clf_logits.squeeze()) > 0.5
                            f1 = f1_score(sentiments_b.cpu().numpy(),
                                          preds.detach().cpu().numpy(),
                                          average='weighted')
                            f1_log.append(f1)
                        elif constant.use_sentiment and constant.use_sentiment_agreement:
                            step_loss, dec_lens_var, rs, R, sampled_sents, clf_logits = model(
                                dialogs, lens, targets, sentiments=sentiments)
                            clf_loss = clf_criterion(clf_logits, sentiments_b)
                            preds = torch.sigmoid(clf_logits.squeeze()) > 0.5
                            f1 = f1_score(sentiments_b.cpu().numpy(),
                                          preds.detach().cpu().numpy(),
                                          average='weighted')
                            f1_log.append(f1)
                        elif constant.use_sentiment_agreement:
                            step_loss, dec_lens_var, rs, R, sampled_sents = model(
                                dialogs, lens, targets, sentiments=sentiments)
                        elif constant.use_sentiment:
                            step_loss, dec_lens_var, rs, R, sampled_sents, clf_logits = model(
                                dialogs, lens, targets, sentiments=sentiments)
                            clf_loss = clf_criterion(clf_logits, sentiments_b)
                            preds = torch.sigmoid(clf_logits.squeeze()) > 0.5
                            f1 = f1_score(sentiments_b.cpu().numpy(),
                                          preds.detach().cpu().numpy(),
                                          average='weighted')
                            f1_log.append(f1)
                        elif constant.use_curiosity:
                            step_loss, dec_lens_var, rs, R, R_i, L_i, sampled_sents = model(
                                dialogs, lens, targets)
                            rs = rs.transpose(0, 1).contiguous()
                            R_i = R_i.transpose(0, 1).contiguous()
                            baseline_target = R.detach() * R_i.detach()
                            rl_loss = torch.mean(
                                torch.sum(
                                    (R.detach() * R_i.detach() - rs.detach()) *
                                    step_loss,
                                    dim=1) / dec_lens_var.float())
                            R_i = torch.mean(
                                torch.sum(R_i, dim=1) / dec_lens_var.float())
                        else:
                            step_loss, dec_lens_var, rs, R, sampled_sents = model(
                                dialogs, lens, targets)

                        if not constant.use_curiosity:
                            # probs = probs.transpose(0, 1).contiguous()
                            # xs = xs.transpose(0, 1).contiguous()
                            # # (B, T, V) => (B, T) => (B,)
                            # probs = torch.gather(probs, dim=2, index=xs.unsqueeze(2)).squeeze()
                            # probs = -torch.log(probs)
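                            # Policy-gradient (REINFORCE) update with a learned baseline:
                            # weight each decoding step's loss by the advantage (R - r_t),
                            # where rs holds the per-step baseline estimates, then
                            # normalize by the decoded sequence length.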
                            rs = rs.transpose(0, 1).contiguous()
                            rl_loss = torch.mean(
                                torch.sum(
                                    (R.detach() - rs.detach()) * step_loss,
                                    dim=1) / dec_lens_var.float())

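                    # Hybrid objective: combine the RL loss with a standard MLE
                    # cross-entropy loss, weighted by constant.lambda_mle.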
                    if constant.use_hybrid:
                        probs, _ = model(dialogs, lens, targets, use_mle=True)
                        mle_loss = mle_criterion(
                            probs.transpose(0, 1).contiguous().view(B * T, -1),
                            targets.contiguous().view(B * T))
                        loss = constant.lambda_mle * rl_loss + (
                            1 - constant.lambda_mle) * mle_loss
                    elif constant.use_arl:
                        probs, _ = model(dialogs, lens, targets, use_mle=True)
                        arl_c = torch.ones(arl.size()).to(arl.device) - arl
                        mle_criterion.reduction = 'none'
                        mle_loss = mle_criterion(
                            probs.transpose(0, 1).contiguous().view(B * T, -1),
                            targets.contiguous().view(B * T))
                        mle_loss = torch.mean(
                            torch.sum(mle_loss * arl_c, dim=1))
                        loss = rl_loss + mle_loss
                    else:
                        loss = rl_loss

                    if constant.use_sentiment:
                        loss = constant.lambda_emo * clf_loss + (
                            1 - constant.lambda_emo) * loss

                    if constant.use_curiosity:
                        loss = pretrain_curiosity * loss + (
                            1 - constant.beta) * L_i + constant.beta * R_i

                    loss.backward()
                    opt.step()

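                    # Train the baseline reward estimator separately: regress the
                    # per-step predictions rs toward the observed return so the
                    # advantage used above has lower variance.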
                    if constant.use_baseline:
                        if constant.use_curiosity:
                            baseline_loss = baseline_criterion(
                                rs, baseline_target)
                        else:
                            # rs (32, T) <==> R (32, 1)
                            baseline_loss = baseline_criterion(
                                rs, tile(R, T, dim=1))
                        baseline_loss.backward()
                        opt.step()

                    ## logging
                    reward_log.append(torch.mean(R).item())
                    if constant.use_sentiment and constant.aux_reward_model != '':
                        ori_reward_log.append(torch.mean(R_l).item())
                        aux_reward_log.append(torch.mean(R_s).item())

                    if constant.use_curiosity:
                        aux_reward_log.append(torch.mean(R_i).item())
                        inv_loss_log.append(L_i.item())

                    if not constant.grid_search:
                        if constant.use_sentiment:
                            if constant.aux_reward_model != '':
                                pbar.set_description(
                                    "(Epoch {}) TRAIN R: {:.3f} R_l: {:.3f} R_s: {:.3f} F1: {:.3f}"
                                    .format(e, np.mean(reward_log),
                                            np.mean(ori_reward_log),
                                            np.mean(aux_reward_log),
                                            np.mean(f1_log)))
                            else:
                                pbar.set_description(
                                    "(Epoch {}) TRAIN REWARD: {:.4f} TRAIN F1: {:.4f}"
                                    .format(e, np.mean(reward_log),
                                            np.mean(f1_log)))
                        elif constant.use_curiosity:
                            pbar.set_description(
                                "(Epoch {}) TRAIN R: {:.3f} R_i: {:.3f} L_i: {:.3f}"
                                .format(e, np.mean(reward_log),
                                        np.mean(aux_reward_log),
                                        np.mean(inv_loss_log)))
                        else:
                            pbar.set_description(
                                "(Epoch {}) TRAIN REWARD: {:.4f}".format(
                                    e, np.mean(reward_log)))

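                    # Every 100 batches, decode and print the dialog context, the
                    # sampled response, and the gold response for qualitative inspection.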
                    if b % 100 == 0 and b > 0:
                        # if not constant.use_self_critical:
                        #     _, greedy_sents = model(dialogs, lens, targets, test=True, use_mle=True)
                        corrects = [
                            " ".join([
                                train_dataloader.dataset.lang.index2word[x_t]
                                for x_t in iter(lambda x=iter(gens): next(x),
                                                constant.eou_idx)
                            ]) for gens in targets.cpu().data.numpy()
                        ]
                        contexts = [
                            " ".join([
                                train_dataloader.dataset.lang.index2word[x_t]
                                for x_t in iter(lambda x=iter(gens): next(x),
                                                constant.pad_idx)
                            ]) for gens in dialogs.cpu().data.numpy()
                        ]
                        for d, c, s, r in zip(contexts, corrects,
                                              sampled_sents,
                                              R.detach().cpu().numpy()):
                            print('reward: ', r)
                            print('dialog: ', d)
                            print('sample: ', s)
                            print('golden: ', c)
                            print()
                except RuntimeError as err:
                    if 'out of memory' in str(err):
                        print('| WARNING: ran out of memory, skipping batch')
                        torch.cuda.empty_cache()
                    else:
                        print(err)
                        traceback.print_exc()
                        raise err

            ## LOG
            if constant.use_sentiment and not constant.use_sentiment_agreement:
                dev_reward, dev_f1 = eval_rl(model, dev_dataloader, bleu=False)
                print("(Epoch {}) DEV REWARD: {:.4f}".format(e, dev_reward))
            elif constant.use_curiosity:
                dev_reward, dev_Ri, dev_Li = eval_rl(model,
                                                     dev_dataloader,
                                                     bleu=False)
                print("(Epoch {}) DEV REWARD: {:.3f} R_i: {:.3f} L_i: {:.3f}".
                      format(e, dev_reward, dev_Ri, dev_Li))
            else:
                dev_reward = eval_rl(model, dev_dataloader, bleu=False)
                print("(Epoch {}) DEV REWARD: {:.4f}".format(e, dev_reward))

            scheduler.step(dev_reward)
            if (dev_reward > best_dev):
                best_dev = dev_reward
                # save best model
                path = 'trained/data-{}.task-rlseq.lr-{}.tau-{}.lambda-{}.reward-{}.{}'
                path = path.format(constant.data, constant.lr, tau,
                                   constant.lambda_mle, best_dev,
                                   constant.reward_model.split('/')[1])
                if constant.use_curiosity:
                    path += '.curiosity'
                if constant.aux_reward_model != '':
                    path += '.' + constant.aux_reward_model.split('/')[1]
                    path += '.lambda_aux-{}'.format(constant.lambda_aux)
                if constant.use_tau_anneal:
                    path += '.tau_anneal'
                if constant.use_self_critical:
                    path += '.self_critical'
                if constant.use_current:
                    path += '.current'
                if constant.use_sentiment:
                    path += '.sentiment'
                if constant.use_sentiment_agreement:
                    path += '.agreement'
                if constant.use_context:
                    path += '.context'
                if constant.topk:
                    path += '.topk-{}'.format(constant.topk_size)
                if constant.use_arl:
                    path += '.arl'
                if constant.grid_search:
                    path += '.grid'
                best_path = save_model(model, 'reward', best_dev, path)
                patience = 3
            else:
                patience -= 1
            if patience == 0: break
            if constant.aux_reward_model == '' and best_dev == 0.0: break

    except KeyboardInterrupt:
        if not constant.grid_search:
            print("KEYBOARD INTERRUPT: Save CKPT and Eval")
            save = input('Save ckpt? (y/n)\t') in ['y', 'Y', 'yes', 'Yes']
            if save:
                save_path = save_ckpt(model, opt, e)
                print("Saved CKPT path: ", save_path)
            # ask if eval
            do_eval = input('Proceed with eval? (y/n)\t') in ['y', 'Y', 'yes', 'Yes']
            if do_eval:
                if constant.use_sentiment:
                    if constant.aux_reward_model != '':
                        dev_rewards, dev_f1, dev_bleu, dev_bleus = eval_rl(
                            model, dev_dataloader, bleu=True)
                        print(
                            "DEV R: {:.3f} R_l: {:.3f} R_s: {:.3f} DEV F1: {:.3f} DEV B: {:.3f}"
                            .format(dev_rewards[0], dev_rewards[1],
                                    dev_rewards[2], dev_f1, dev_bleu))
                    else:
                        dev_reward, dev_f1, dev_bleu, dev_bleus = eval_rl(
                            model, dev_dataloader, bleu=True)
                        print(
                            "DEV REWARD: {:.4f}, DEV F1: {:.4f}, DEV BLEU: {:.4f}"
                            .format(dev_reward, dev_f1, dev_bleu))
                elif constant.use_curiosity:
                    dev_reward, dev_Ri, dev_Li, dev_bleu, dev_bleus = eval_rl(
                        model, dev_dataloader, bleu=True)
                    print(
                        "BEST DEV REWARD: {:.4f} R_i: {:.3f} L_i: {:.3f} BLEU: {:.4f}"
                        .format(dev_reward, dev_Ri, dev_Li, dev_bleu))
                else:
                    dev_reward, dev_bleu, dev_bleus = eval_rl(model,
                                                              dev_dataloader,
                                                              bleu=True)
                    print("DEV REWARD: {:.4f}, DEV BLEU: {:.4f}".format(
                        dev_reward, dev_bleu))
                print(
                    "BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}"
                    .format(dev_bleus[0], dev_bleus[1], dev_bleus[2],
                            dev_bleus[3]))
        exit(1)

    # load and report best model on test
    torch.cuda.empty_cache()
    model = load_model(model, best_path)
    if constant.USE_CUDA:
        model.cuda()

    if constant.use_sentiment and not constant.use_sentiment_agreement:
        if constant.aux_reward_model != '':
            dev_rewards, dev_f1, dev_bleu, dev_bleus = eval_rl(model,
                                                               dev_dataloader,
                                                               bleu=True)
            test_rewards, test_f1, test_bleu, test_bleus = eval_rl(
                model, test_dataloader, bleu=True)
            print(
                "DEV R: {:.3f} R_l: {:.3f} R_s: {:.3f} DEV F1: {:.3f} DEV B: {:.3f}"
                .format(dev_rewards[0], dev_rewards[1], dev_rewards[2], dev_f1,
                        dev_bleu))
            print(
                "BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}"
                .format(dev_bleus[0], dev_bleus[1], dev_bleus[2],
                        dev_bleus[3]))
            print(
                "TEST R: {:.3f} R_l: {:.3f} R_s: {:.3f} TEST F1: {:.3f} TEST B: {:.3f}"
                .format(test_rewards[0], test_rewards[1], test_rewards[2],
                        test_f1, test_bleu))
            print(
                "BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}"
                .format(test_bleus[0], test_bleus[1], test_bleus[2],
                        test_bleus[3]))
        else:
            dev_reward, dev_f1, dev_bleu, dev_bleus = eval_rl(model,
                                                              dev_dataloader,
                                                              bleu=True)
            test_reward, test_f1, test_bleu, test_bleus = eval_rl(
                model, test_dataloader, bleu=True)
            print(
                "DEV REWARD: {:.4f}, DEV F1: {:.4f}, DEV BLEU: {:.4f}".format(
                    dev_reward, dev_f1, dev_bleu))
            print(
                "BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}"
                .format(dev_bleus[0], dev_bleus[1], dev_bleus[2],
                        dev_bleus[3]))
            print("TEST REWARD: {:.4f}, TEST F1: {:.4f}, TEST BLEU: {:.4f}".
                  format(test_reward, test_f1, test_bleu))
            print(
                "BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}"
                .format(test_bleus[0], test_bleus[1], test_bleus[2],
                        test_bleus[3]))
    elif constant.use_curiosity:
        dev_reward, dev_Ri, dev_Li, dev_bleu, dev_bleus = eval_rl(
            model, dev_dataloader, bleu=True)
        test_reward, test_Ri, test_Li, test_bleu, test_bleus = eval_rl(
            model, test_dataloader, bleu=True)
        print("BEST DEV REWARD: {:.4f} R_i: {:.3f} L_i: {:.3f} BLEU: {:.4f}".
              format(dev_reward, dev_Ri, dev_Li, dev_bleu))
        print("BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}".
              format(dev_bleus[0], dev_bleus[1], dev_bleus[2], dev_bleus[3]))
        print("BEST TEST REWARD: {:.4f} R_i: {:.3f} L_i: {:.3f} BLEU: {:.4f}".
              format(test_reward, test_Ri, test_Li, test_bleu))
        print("BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}".
              format(test_bleus[0], test_bleus[1], test_bleus[2],
                     test_bleus[3]))
    else:
        dev_reward, dev_bleu, dev_bleus = eval_rl(model,
                                                  dev_dataloader,
                                                  bleu=True)
        test_reward, test_bleu, test_bleus = eval_rl(model,
                                                     test_dataloader,
                                                     bleu=True)
        print("BEST DEV REWARD: {:.4f}, BLEU: {:.4f}".format(
            dev_reward, dev_bleu))
        print("BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}".
              format(dev_bleus[0], dev_bleus[1], dev_bleus[2], dev_bleus[3]))
        print("BEST TEST REWARD: {:.4f}, BLEU: {:.4f}".format(
            test_reward, test_bleu))
        print("BLEU 1: {:.4f}, BLEU 2: {:.4f}, BLEU 3: {:.4f}, BLEU 4: {:.4f}".
              format(test_bleus[0], test_bleus[1], test_bleus[2],
                     test_bleus[3]))
Code example #21
0
def main(args):
    print('==> Using settings {}'.format(args))
    device = torch.device("cuda")

    print('==> Loading dataset...')
    data_dict = data_preparation(args)

    print("==> Creating PoseNet model...")
    model_pos = model_pos_preparation(args, data_dict['dataset'], device)
    model_pos_eval = model_pos_preparation(args, data_dict['dataset'],
                                           device)  # used for evaluation only
    # prepare optimizer for posenet
    posenet_optimizer = torch.optim.Adam(model_pos.parameters(), lr=args.lr_p)
    posenet_lr_scheduler = get_scheduler(posenet_optimizer,
                                         policy='lambda',
                                         nepoch_fix=0,
                                         nepoch=args.epochs)

    print("==> Creating PoseAug model...")
    poseaug_dict = get_poseaug_model(args, data_dict['dataset'])

    # loss function
    criterion = nn.MSELoss(reduction='mean').to(device)

    # GAN trick: data buffer for fake data
    fake_3d_sample = Sample_from_Pool()
    fake_2d_sample = Sample_from_Pool()

    args.checkpoint = path.join(
        args.checkpoint, args.posenet_name, args.keypoints,
        datetime.datetime.now().isoformat() + '_' + args.note)
    os.makedirs(args.checkpoint, exist_ok=True)
    print('==> Making checkpoint dir: {}'.format(args.checkpoint))

    logger = Logger(os.path.join(args.checkpoint, 'log.txt'), args)
    logger.record_args(str(model_pos))
    logger.set_names([
        'epoch', 'lr', 'error_h36m_p1', 'error_h36m_p2', 'error_3dhp_p1',
        'error_3dhp_p2'
    ])

    # Init monitor for network training
    #########################################################
    summary = Summary(args.checkpoint)
    writer = summary.create_summary()

    ##########################################################
    # start training
    ##########################################################
    start_epoch = 0
    dhpp1_best = None
    s911p1_best = None

    for _ in range(start_epoch, args.epochs):

        if summary.epoch == 0:
            # evaluate the pre-trained model at epoch 0.
            h36m_p1, h36m_p2, dhp_p1, dhp_p2 = evaluate_posenet(args,
                                                                data_dict,
                                                                model_pos,
                                                                model_pos_eval,
                                                                device,
                                                                summary,
                                                                writer,
                                                                tag='_fake')
            h36m_p1, h36m_p2, dhp_p1, dhp_p2 = evaluate_posenet(args,
                                                                data_dict,
                                                                model_pos,
                                                                model_pos_eval,
                                                                device,
                                                                summary,
                                                                writer,
                                                                tag='_real')
            summary.summary_epoch_update()

        # update train loader
        dataloader_update(args=args, data_dict=data_dict, device=device)

        # Train for one epoch
        train_gan(args, poseaug_dict, data_dict, model_pos, criterion,
                  fake_3d_sample, fake_2d_sample, summary, writer)

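        # After the GAN warm-up phase, fine-tune PoseNet on the augmented (fake)
        # 2D-3D pairs and on the detected (real) pairs, evaluating after each pass.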
        if summary.epoch > args.warmup:
            train_posenet(model_pos, data_dict['train_fake2d3d_loader'],
                          posenet_optimizer, criterion, device)
            h36m_p1, h36m_p2, dhp_p1, dhp_p2 = evaluate_posenet(args,
                                                                data_dict,
                                                                model_pos,
                                                                model_pos_eval,
                                                                device,
                                                                summary,
                                                                writer,
                                                                tag='_fake')

            train_posenet(model_pos, data_dict['train_det2d3d_loader'],
                          posenet_optimizer, criterion, device)
            h36m_p1, h36m_p2, dhp_p1, dhp_p2 = evaluate_posenet(args,
                                                                data_dict,
                                                                model_pos,
                                                                model_pos_eval,
                                                                device,
                                                                summary,
                                                                writer,
                                                                tag='_real')
        # Update learning rates
        ########################
        poseaug_dict['scheduler_G'].step()
        poseaug_dict['scheduler_d3d'].step()
        poseaug_dict['scheduler_d2d'].step()
        posenet_lr_scheduler.step()
        lr_now = posenet_optimizer.param_groups[0]['lr']
        print('\nEpoch: %d | LR: %.8f' % (summary.epoch, lr_now))

        # Update log file
        logger.append(
            [summary.epoch, lr_now, h36m_p1, h36m_p2, dhp_p1, dhp_p2])

        # Update checkpoint
        if dhpp1_best is None or dhpp1_best > dhp_p1:
            dhpp1_best = dhp_p1
            logger.record_args(
                "==> Saving checkpoint at epoch '{}', with dhp_p1 {}".format(
                    summary.epoch, dhpp1_best))
            save_ckpt(
                {
                    'epoch': summary.epoch,
                    'model_pos': model_pos.state_dict()
                },
                args.checkpoint,
                suffix='best_dhp_p1')

        if s911p1_best is None or s911p1_best > h36m_p1:
            s911p1_best = h36m_p1
            logger.record_args(
                "==> Saving checkpoint at epoch '{}', with s911p1 {}".format(
                    summary.epoch, s911p1_best))
            save_ckpt(
                {
                    'epoch': summary.epoch,
                    'model_pos': model_pos.state_dict()
                },
                args.checkpoint,
                suffix='best_h36m_p1')

        summary.summary_epoch_update()

    writer.close()
    logger.close()
Code example #22
0
def train_emotion(model, dataloaders):
    """ 
    Training loop
    Inputs:
        model: the model to be trained
        dataloader: data loader
    Output:
        best_dev: best f1 score on dev data
        best_test: best f1 score on test data
    """
    train_dataloader, dev_dataloader, test_dataloader = dataloaders
    if (constant.USE_CUDA): model.cuda()
    criterion = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=constant.lr)

    best_dev = 0
    best_test = 0
    patience = 3

    try:
        for e in range(constant.epochs):
            model.train()
            loss_log = []
            f1_log = []

            if constant.grid_search:
                pbar = enumerate(train_dataloader)
            else:
                pbar = tqdm(enumerate(train_dataloader),
                            total=len(train_dataloader))

            for _, (dialogs, lens, _, _, emotions, _, _, _) in pbar:
                opt.zero_grad()
                logits = model(dialogs, lens)

                loss = criterion(logits, emotions)
                loss.backward()
                opt.step()

                ## logging
                loss_log.append(loss.item())
                preds = torch.argmax(logits, dim=1)
                f1 = f1_score(emotions.cpu().numpy(),
                              preds.detach().cpu().numpy(),
                              average='weighted')
                # _, _, _, microF1 = get_metrics(logits.detach().cpu().numpy(), emotions.cpu().numpy())
                f1_log.append(f1)
                if not constant.grid_search:
                    pbar.set_description(
                        "(Epoch {}) TRAIN F1:{:.4f} TRAIN LOSS:{:.4f}".format(
                            e + 1, np.mean(f1_log), np.mean(loss_log)))

            ## LOG
            f1 = eval_emotion(model, dev_dataloader)
            testF1 = eval_emotion(model, test_dataloader)
            print("(Epoch {}) DEV F1: {:.4f} TEST F1: {:.4f}".format(
                e + 1, f1, testF1))
            print("(Epoch {}) BEST DEV F1: {:.4f} BEST TEST F1: {:.4f}".format(
                e + 1, best_dev, best_test))
            if (f1 > best_dev):
                best_dev = f1
                best_test = testF1
                patience = 3
            else:
                patience -= 1
            if (patience == 0): break
            if (best_dev == 1.0): break

    except KeyboardInterrupt:
        if not constant.grid_search:
            print("KEYBOARD INTERRUPT: Save CKPT and Eval")
            save = input('Save ckpt? (y/n)\t') in ['y', 'Y', 'yes', 'Yes']
            if save:
                save_path = save_ckpt(model, opt, e)
                print("Saved CKPT path: ", save_path)
            print("BEST SCORES - DEV F1: {:.4f}, TEST F1: {:.4f}".format(
                best_dev, best_test))
        exit(1)

    print("BEST SCORES - DEV F1: {:.4f}, TEST F1: {:.4f}".format(
        best_dev, best_test))
Code example #23
0
File: main.py Project: YiqunChen1999/DLTemplate
def main():
    # Set logger to record information.
    utils.check_env(cfg)
    logger = Logger(cfg)
    logger.log_info(cfg)
    metrics_handler = MetricsHandler(cfg.metrics)
    # utils.pack_code(cfg, logger=logger)

    # Build model.
    model = model_builder.build_model(cfg=cfg, logger=logger)
    optimizer = optimizer_helper.build_optimizer(cfg=cfg, model=model)
    lr_scheduler = lr_scheduler_helper.build_scheduler(cfg=cfg,
                                                       optimizer=optimizer)

    # Read checkpoint.
    ckpt = torch.load(cfg.model.path2ckpt) if cfg.gnrl.resume else {}
    if cfg.gnrl.resume:
        with logger.log_info(msg="Load pre-trained model.",
                             level="INFO",
                             state=True,
                             logger=logger):
            model.load_state_dict(ckpt["model"])
            optimizer.load_state_dict(ckpt["optimizer"])
            lr_scheduler.load_state_dict(ckpt["lr_scheduler"])

    # Set device.
    model, device = utils.set_pipline(
        model, cfg) if cfg.gnrl.PIPLINE else utils.set_device(
            model, cfg.gnrl.cuda)

    resume_epoch = ckpt["epoch"] if cfg.gnrl.resume else 0
    loss_fn = loss_fn_helper.build_loss_fn(cfg=cfg)

    # Prepare dataset.
    train_loaders, valid_loaders, test_loaders = dict(), dict(), dict()
    for dataset in cfg.data.datasets:
        if cfg.data[dataset].TRAIN:
            try:
                train_loaders[dataset] = data_loader.build_data_loader(
                    cfg, dataset, "train")
            except:
                utils.notify(msg="Failed to build train loader of %s" %
                             dataset)
        if cfg.data[dataset].VALID:
            try:
                valid_loaders[dataset] = data_loader.build_data_loader(
                    cfg, dataset, "valid")
            except:
                utils.notify(msg="Failed to build valid loader of %s" %
                             dataset)
        if cfg.data[dataset].TEST:
            try:
                test_loaders[dataset] = data_loader.build_data_loader(
                    cfg, dataset, "test")
            except:
                utils.notify(msg="Failed to build test loader of %s" % dataset)

    # TODO Train, evaluate model and save checkpoint.
    for epoch in range(cfg.train.max_epoch):
        epoch += 1
        if resume_epoch >= epoch:
            continue

        eval_kwargs = {
            "epoch": epoch,
            "cfg": cfg,
            "model": model,
            "loss_fn": loss_fn,
            "device": device,
            "metrics_handler": metrics_handler,
            "logger": logger,
            "save": cfg.save.save,
        }
        train_kwargs = {
            "epoch": epoch,
            "cfg": cfg,
            "model": model,
            "loss_fn": loss_fn,
            "optimizer": optimizer,
            "device": device,
            "lr_scheduler": lr_scheduler,
            "metrics_handler": metrics_handler,
            "logger": logger,
        }
        ckpt_kwargs = {
            "epoch": epoch,
            "cfg": cfg,
            "model": model.state_dict(),
            "metrics_handler": metrics_handler,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
        }

        for dataset in cfg.data.datasets:
            if cfg.data[dataset].TRAIN:
                utils.notify("Train on %s" % dataset)
                train_one_epoch(data_loader=train_loaders[dataset],
                                **train_kwargs)

        utils.save_ckpt(path2file=cfg.model.path2ckpt, **ckpt_kwargs)

        if epoch in cfg.gnrl.ckphs:
            utils.save_ckpt(path2file=os.path.join(
                cfg.model.ckpts,
                cfg.gnrl.id + "_" + str(epoch).zfill(5) + ".pth"),
                            **ckpt_kwargs)
            for dataset in cfg.data.datasets:
                utils.notify("Evaluating test set of %s" % dataset,
                             logger=logger)
                if cfg.data[dataset].TEST:
                    evaluate(data_loader=test_loaders[dataset],
                             phase="test",
                             **eval_kwargs)

        for dataset in cfg.data.datasets:
            utils.notify("Evaluating valid set of %s" % dataset, logger=logger)
            if cfg.data[dataset].VALID:
                evaluate(data_loader=valid_loaders[dataset],
                         phase="valid",
                         **eval_kwargs)
    # End of train-valid for loop.

    eval_kwargs = {
        "epoch": epoch,
        "cfg": cfg,
        "model": model,
        "loss_fn": loss_fn,
        "device": device,
        "metrics_handler": metrics_handler,
        "logger": logger,
        "save": cfg.save.save,
    }

    for dataset in cfg.data.datasets:
        if cfg.data[dataset].VALID:
            utils.notify("Evaluating valid set of %s" % dataset, logger=logger)
            evaluate(data_loader=valid_loaders[dataset],
                     phase="valid",
                     **eval_kwargs)
    for dataset in cfg.data.datasets:
        if cfg.data[dataset].TEST:
            utils.notify("Evaluating test set of %s" % dataset, logger=logger)
            evaluate(data_loader=test_loaders[dataset],
                     phase="test",
                     **eval_kwargs)

    for dataset in cfg.data.datasets:
        if "train" in cfg.data[dataset].INFER:
            utils.notify("Inference on train set of %s" % dataset)
            inference(data_loader=train_loaders[dataset],
                      phase="infer_train",
                      **eval_kwargs)
        if "valid" in cfg.data[dataset].INFER:
            utils.notify("Inference on valid set of %s" % dataset)
            inference(data_loader=valid_loaders[dataset],
                      phase="infer_valid",
                      **eval_kwargs)
        if "test" in cfg.data[dataset].INFER:
            utils.notify("Inference on test set of %s" % dataset)
            inference(data_loader=test_loaders[dataset],
                      phase="infer_test",
                      **eval_kwargs)

    return None
Code example #24
0
File: train.py Project: ClaraBing/PyTorch-BiGAN
    def train(self):
        """Training the BiGAN"""
        if self.args.data == 'mnist':
            img_channels = 1
            self.G = Generator_small(img_channels,
                                     self.args.latent_dim,
                                     use_tanh=self.args.normalize_data).to(
                                         self.device)
            self.E = Encoder_small(img_channels, self.args.latent_dim,
                                   self.args.use_relu_z,
                                   self.args.first_filter_size).to(self.device)
            self.D = Discriminator_small(img_channels, self.args.latent_dim,
                                         self.args.wasserstein).to(self.device)
        else:
            img_channels = 3
            self.G = Generator(img_channels,
                               self.args.latent_dim,
                               use_tanh=self.args.normalize_data).to(
                                   self.device)
            self.E = Encoder(img_channels, self.args.latent_dim,
                             self.args.use_relu_z,
                             self.args.first_filter_size).to(self.device)
            self.D = Discriminator(img_channels, self.args.latent_dim,
                                   self.args.wasserstein).to(self.device)

        if self.args.pretrained_path and os.path.exists(
                self.args.pretrained_path):
            ckpt = torch.load(self.args.pretrained_path)
            self.G.load_state_dict(ckpt['G'])
            self.E.load_state_dict(ckpt['E'])
            self.D.load_state_dict(ckpt['D'])
        else:
            self.G.apply(weights_init_normal)
            self.E.apply(weights_init_normal)
            self.D.apply(weights_init_normal)

        if self.args.freeze_GD:
            # Train the encoder only, with the generator & discriminator frozen.
            self.G.eval()
            self.D.eval()
            optimizer_d = None
            if self.args.wasserstein:
                optimizer_ge = optim.RMSprop(list(self.E.parameters()),
                                             lr=self.args.lr_rmsprop)
            else:
                optimizer_ge = optim.Adam(list(self.E.parameters()),
                                          lr=self.args.lr_adam,
                                          weight_decay=1e-6)
        else:
            if self.args.wasserstein:
                optimizer_ge = optim.RMSprop(list(self.G.parameters()) +
                                             list(self.E.parameters()),
                                             lr=self.args.lr_rmsprop)
                optimizer_d = optim.RMSprop(self.D.parameters(),
                                            lr=self.args.lr_rmsprop)
            else:
                optimizer_ge = optim.Adam(list(self.G.parameters()) +
                                          list(self.E.parameters()),
                                          lr=self.args.lr_adam,
                                          weight_decay=1e-6)
                optimizer_d = optim.Adam(self.D.parameters(),
                                         lr=self.args.lr_adam,
                                         weight_decay=1e-6)

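        # A fixed batch of latent codes, reused to visualize generator progress.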
        fixed_z = Variable(torch.randn((16, self.args.latent_dim, 1, 1)),
                           requires_grad=False).to(self.device)
        criterion = nn.BCELoss()
        for epoch in range(self.args.num_epochs):
            ge_losses = 0
            d_losses = 0
            for x, xi in Bar(self.train_loader):
                #Defining labels
                y_true = Variable(torch.ones((x.size(0), 1)).to(self.device))
                y_fake = Variable(torch.zeros((x.size(0), 1)).to(self.device))

                #Noise for improving training.
                if epoch < self.args.num_epochs:
                    noise1 = Variable(torch.Tensor(x.size()).normal_(
                        0, 0.1 * (self.args.num_epochs - epoch) /
                        self.args.num_epochs),
                                      requires_grad=False).to(self.device)
                    noise2 = Variable(torch.Tensor(x.size()).normal_(
                        0, 0.1 * (self.args.num_epochs - epoch) /
                        self.args.num_epochs),
                                      requires_grad=False).to(self.device)
                else:
                    # NOTE: added by BB: otherwise the normal_ call above errors with std=0 in the last epoch
                    noise1, noise2 = 0, 0

                #Cleaning gradients.
                if optimizer_d:
                    optimizer_d.zero_grad()
                optimizer_ge.zero_grad()

                #Generator:
                z_fake = Variable(torch.randn(
                    (x.size(0), self.args.latent_dim, 1, 1)).to(self.device),
                                  requires_grad=False)
                x_fake = self.G(z_fake)

                #Encoder:
                x_true = x.float().to(self.device)
                # BB's NOTE: x_true has values in [0, 1]
                z_true = self.E(x_true)

                #Discriminator
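                # BiGAN: the discriminator scores joint pairs, (real image, encoded z)
                # vs. (generated image, sampled z); the small Gaussian noise above
                # stabilizes training.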
                out_true = self.D(x_true + noise1, z_true)
                out_fake = self.D(x_fake + noise2, z_fake)

                #Losses
                if self.args.wasserstein:
                    loss_d = -torch.mean(out_true) + torch.mean(out_fake)
                else:
                    loss_d = criterion(out_true, y_true) + criterion(
                        out_fake, y_fake)

                #Computing gradients and backpropagate.
                loss_d.backward()
                if optimizer_d:
                    optimizer_d.step()

                #Cleaning gradients.
                optimizer_ge.zero_grad()

                #Generator:
                z_fake = Variable(torch.randn(
                    (x.size(0), self.args.latent_dim, 1, 1)).to(self.device),
                                  requires_grad=False)
                x_fake = self.G(z_fake)

                #Encoder:
                x_true = x.float().to(self.device)
                z_true = self.E(x_true)

                #Discriminator
                out_true = self.D(x_true + noise1, z_true)
                out_fake = self.D(x_fake + noise2, z_fake)

                #Losses
                if self.args.wasserstein:
                    loss_ge = -torch.mean(out_fake) + torch.mean(out_true)
                else:
                    loss_ge = criterion(out_fake, y_true) + criterion(
                        out_true, y_fake)

                if self.args.use_l2_loss:
                    loss_ge += self.get_latent_l2_loss()
                    loss_ge += self.get_image_l2_loss(x_true)

                loss_ge.backward()
                optimizer_ge.step()

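                # WGAN: clip the critic's weights to enforce the Lipschitz constraint.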
                if self.args.wasserstein:
                    for p in self.D.parameters():
                        p.data.clamp_(-self.args.clamp, self.args.clamp)

                ge_losses += loss_ge.item()
                d_losses += loss_d.item()

                if USE_WANDB:
                    wandb.log({
                        'iter': epoch * len(self.train_loader) + xi,
                        'loss_ge': loss_ge.item(),
                        'loss_d': loss_d.item(),
                    })

            if epoch % 50 == 0:
                images = self.G(fixed_z).data
                vutils.save_image(images, './images/{}_fake.png'.format(epoch))
                if USE_WANDB:
                    images_lst = [
                        wandb.Image(image.cpu().numpy().transpose(1, 2, 0) * 255,
                                    caption="Epoch {}, #{}".format(epoch, ii))
                        for ii, image in enumerate(images)
                    ]
                    wandb.log({"examples": images_lst})
                if self.args.save_path:
                    save_ckpt(
                        self,
                        self.args.save_path.replace(
                            '.pt', '_tmp_e{}.pt'.format(epoch)))
                else:
                    save_ckpt(
                        self, 'ckpt_epoch{}_tmp_e{}.pt'.format(
                            self.args.num_epochs, epoch))

            print(
                "Training... Epoch: {}, Discrimiantor Loss: {:.3f}, Generator Loss: {:.3f}"
                .format(epoch, d_losses / len(self.train_loader),
                        ge_losses / len(self.train_loader)))