Example #1
def main(with_gui=None, check_stop=None):
    device_idx, model_name, generator_cfg, _, train_cfg = load_config()
    device = get_device()
    if with_gui is None:
        with_gui = train_cfg.get('with_gui')
    model = None
    loss_f = BCELoss(reduction='none')
    pause = False
    images, count, loss_sum, epoch, acc_sum, acc_sum_p = 0, 0, 0, 1, 0, 0
    generator, generator_cfg = None, None
    optimizer, optimizer_cfg = None, None
    best_loss = None
    while check_stop is None or not check_stop():

        # Check config:
        if images == 0:
            cfg = load_config()
            if model_name != cfg.model or model is None:
                model_name = cfg.model
                model, best_loss, epoch = load_model(model_name, train=True, device=device)
                log(model_name, 'Loaded model %s' % model_name)
                optimizer_cfg = None
            if optimizer_cfg != cfg.optimizer:
                optimizer_cfg = cfg.optimizer
                optimizer = create_optimizer(optimizer_cfg, model)
                log(model_name, 'Created optimizer %s' % str(optimizer))
            if generator_cfg != cfg.generator:
                generator_cfg = cfg.generator
                generator = create_generator(generator_cfg, device=device)
                log(model_name, 'Created generator')
            train_cfg = cfg.train

        # Run:
        x, target = next(generator)
        mask, _ = target.max(dim=1, keepdim=True)
        optimizer.zero_grad()
        y = model(x)

        # Save for debug:
        if train_cfg.get('save', False):
            show_images(x, 'input', save_dir='debug')
            show_images(y, 'output', mask=True, save_dir='debug')
            show_images(target, 'target', mask=mask, save_dir='debug')

        # GUI:
        if with_gui:
            if not pause:
                show_images(x, 'input')
                show_images(y, 'output', mask=True, step=2)
                show_images(target, 'target', mask=mask, step=2)
            key = cv2.waitKey(1)
            if key == ord('s'):
                torch.save(model.state_dict(), 'models/unet2.pt')
            elif key == ord('p'):
                pause = not pause
            elif key == ord('q'):
                break

        # Optimize:
        acc_sum += check_accuracy(y, target)
        loss = (loss_f(y, target) * mask).mean()
        loss_item = loss.item()
        loss_sum += loss_item
        count += 1
        images += len(x)
        loss.backward()
        optimizer.step()

        # Complete epoch:
        if images >= train_cfg['epoch_images']:
            acc_total, names = acc_sum, channel_names
            msg = 'Epoch %d: train loss %f, acc %s' % (
                epoch, loss_sum / count,
                acc_to_str(acc_total, names=names)
            )
            log(model_name, msg)
            count = 0
            images = 0
            loss_sum = 0
            epoch += 1
            acc_sum[:] = 0
            save_model(model_name, model, best_loss, epoch)

    log(model_name, 'Stopped\n')
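The masked loss in Example #1 works because BCELoss(reduction='none') returns a per-element loss that can be weighted before averaging. A minimal, self-contained sketch of that pattern (the tensor shapes and names below are illustrative, not taken from the snippet above):

import torch
from torch.nn import BCELoss

loss_f = BCELoss(reduction='none')            # per-element loss, no reduction
y = torch.rand(4, 3, 8, 8)                    # model output after a sigmoid, values in [0, 1]
target = torch.randint(0, 2, (4, 3, 8, 8)).float()
mask, _ = target.max(dim=1, keepdim=True)     # weight only pixels where any channel is positive
loss = (loss_f(y, target) * mask).mean()      # same weighted reduction as in Example #1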
Example #2
    optimizer_grouped_parameters = [{
        'params': [p for n, p in param_optimizer
                   if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params': [p for n, p in param_optimizer
                   if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]

    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=5e-5,
                         warmup=0.1,
                         t_total=num_train_optimization_steps)

    loss_fn = BCELoss()
    model.train()
    loss_epoch = dict()

    for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
        accumulation_steps = 0
        accumulation_loss = []
        loss_batch = []
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            accumulation_steps += 1
            x, y = tuple(t.to(dev) for t in batch)
            output = model(x)
            y = y.type_as(output)
            loss = loss_fn(output.view(-1, 1), y.view(-1, 1))
            accumulation_loss.append(loss.item())
            loss.backward()
Example #3
    try:
        print("Loading pretrained discriminator")
        weights_d = torch.load(os.path.join(folder_name, "dis_gr_ResN_batch_" + batch_size_str + "_wd"
                                            + w_decay_str + "_lr" + lrate_str + "_e" + str(epochs - 1) + ".pth"))
        d.load_state_dict(weights_d)
    except Exception as err:
        raise FileNotFoundError("could not load Discriminator") from err

# ============================================= Training ============================================= #
filename_images = 'gen_imgs_grn_cropped_' + ResNet_str + '_LOGAN_' + str(latent_space_optimisation) + '_SelfAttn_' + str(self_attention_on) + '_Orthogonal'
if not os.path.exists(os.path.join(wd, filename_images)):
    os.mkdir(os.path.join(wd, filename_images))

tb = SummaryWriter(comment=ResNet_str + "_GAN_Orthogonal_LOGAN_batch" + str(
    batch_size) + "_wd" + w_decay_str + "_lr" + lrate_str + "epochs_" + str(num_epochs))
loss = BCELoss()
loss = loss.to(device)
img_list = []
# main train loop
if epochs >= num_epochs:
    raise RuntimeError("we have already trained for this amount of epochs")
for e in range(epochs, num_epochs):
    for id, data in dataloader:
        # print(id)
        # first, train the discriminator
        d.zero_grad()
        g.zero_grad()
        data = data.to(device)
        batch_size = data.size()[0]
        # ======================== latent optimization step if True =========================
        if latent_space_optimisation:
Example #4
def main(args):

    # ================================================
    # Preparation
    # ================================================
    args.data_dir = os.path.expanduser(args.data_dir)
    args.result_dir = os.path.expanduser(args.result_dir)
    if args.init_model_cn is not None:
        args.init_model_cn = os.path.expanduser(args.init_model_cn)
    if args.init_model_cd is not None:
        args.init_model_cd = os.path.expanduser(args.init_model_cd)
    if not torch.cuda.is_available():
        raise Exception('At least one gpu must be available.')
    else:
        gpu = torch.device('cuda:0')

    # create result directory (if necessary)
    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)
    for s in ['phase_1', 'phase_2', 'phase_3']:
        if not os.path.exists(os.path.join(args.result_dir, s)):
            os.makedirs(os.path.join(args.result_dir, s))

    # dataset
    trnsfm = transforms.Compose([
        transforms.Resize(args.cn_input_size),
        transforms.RandomCrop((args.cn_input_size, args.cn_input_size)),
        transforms.ToTensor(),
    ])
    print('loading dataset... (it may take a few minutes)')
    train_dset = ImageDataset(os.path.join(args.data_dir, 'train'),
                              trnsfm,
                              recursive_search=args.recursive_search)
    test_dset = ImageDataset(os.path.join(args.data_dir, 'test'),
                             trnsfm,
                             recursive_search=args.recursive_search)
    train_loader = DataLoader(train_dset,
                              batch_size=(args.bsize // args.bdivs),
                              shuffle=True)

    # compute mean pixel value of training dataset
    mpv = np.zeros(shape=(3, ))
    if args.mpv is None:
        pbar = tqdm(total=len(train_dset.imgpaths),
                    desc='computing mean pixel value for training dataset...')
        for imgpath in train_dset.imgpaths:
            img = Image.open(imgpath)
            x = np.array(img, dtype=np.float32) / 255.
            mpv += x.mean(axis=(0, 1))
            pbar.update()
        mpv /= len(train_dset.imgpaths)
        pbar.close()
    else:
        mpv = np.array(args.mpv)

    # save training config
    mpv_json = []
    for i in range(3):
        mpv_json.append(float(mpv[i]))  # convert to json serializable type
    args_dict = vars(args)
    args_dict['mpv'] = mpv_json
    with open(os.path.join(args.result_dir, 'config.json'), mode='w') as f:
        json.dump(args_dict, f)

    # make mpv & alpha tensor
    mpv = torch.tensor(mpv.astype(np.float32).reshape(1, 3, 1, 1)).to(gpu)
    alpha = torch.tensor(args.alpha).to(gpu)

    # ================================================
    # Training Phase 1
    # ================================================
    model_cn = CompletionNetwork()
    if args.data_parallel:
        model_cn = DataParallel(model_cn)
    if args.init_model_cn is not None:
        model_cn.load_state_dict(
            torch.load(args.init_model_cn, map_location='cpu'))
    if args.optimizer == 'adadelta':
        opt_cn = Adadelta(model_cn.parameters())
    else:
        opt_cn = Adam(model_cn.parameters())
    model_cn = model_cn.to(gpu)

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_1)
    while pbar.n < args.steps_1:
        for x in train_loader:
            # forward
            x = x.to(gpu)
            mask = gen_input_mask(
                shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                hole_size=((args.hole_min_w, args.hole_max_w),
                           (args.hole_min_h, args.hole_max_h)),
                hole_area=gen_hole_area(
                    (args.ld_input_size, args.ld_input_size),
                    (x.shape[3], x.shape[2])),
                max_holes=args.max_holes,
            ).to(gpu)
            x_mask = x - x * mask + mpv * mask
            input = torch.cat((x_mask, mask), dim=1)
            output = model_cn(input)
            loss = completion_network_loss(x, output, mask)

            # backward
            loss.backward()
            cnt_bdivs += 1

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                # optimize
                opt_cn.step()
                # clear grads
                opt_cn.zero_grad()
                # update progbar
                pbar.set_description('phase 1 | train loss: %.5f' % loss.cpu())
                pbar.update()
                # test
                if pbar.n % args.snaperiod_1 == 0:
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        mask = gen_input_mask(
                            shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                            hole_size=((args.hole_min_w, args.hole_max_w),
                                       (args.hole_min_h, args.hole_max_h)),
                            hole_area=gen_hole_area(
                                (args.ld_input_size, args.ld_input_size),
                                (x.shape[3], x.shape[2])),
                            max_holes=args.max_holes,
                        ).to(gpu)
                        x_mask = x - x * mask + mpv * mask
                        input = torch.cat((x_mask, mask), dim=1)
                        output = model_cn(input)
                        completed = poisson_blend(x, output, mask)
                        imgs = torch.cat(
                            (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
                        imgpath = os.path.join(args.result_dir, 'phase_1',
                                               'step%d.png' % pbar.n)
                        model_cn_path = os.path.join(
                            args.result_dir, 'phase_1',
                            'model_cn_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        if args.data_parallel:
                            torch.save(model_cn.module.state_dict(),
                                       model_cn_path)
                        else:
                            torch.save(model_cn.state_dict(), model_cn_path)
                # terminate
                if pbar.n >= args.steps_1:
                    break
    pbar.close()

    # ================================================
    # Training Phase 2
    # ================================================
    model_cd = ContextDiscriminator(
        local_input_shape=(3, args.ld_input_size, args.ld_input_size),
        global_input_shape=(3, args.cn_input_size, args.cn_input_size),
        arc=args.arc,
    )
    if args.data_parallel:
        model_cd = DataParallel(model_cd)
    if args.init_model_cd is not None:
        model_cd.load_state_dict(
            torch.load(args.init_model_cd, map_location='cpu'))
    if args.optimizer == 'adadelta':
        opt_cd = Adadelta(model_cd.parameters())
    else:
        opt_cd = Adam(model_cd.parameters())
    model_cd = model_cd.to(gpu)
    bceloss = BCELoss()

    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_2)
    while pbar.n < args.steps_2:
        for x in train_loader:

            # fake forward
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            mask = gen_input_mask(
                shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                hole_size=((args.hole_min_w, args.hole_max_w),
                           (args.hole_min_h, args.hole_max_h)),
                hole_area=hole_area_fake,
                max_holes=args.max_holes,
            ).to(gpu)
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            input_cn = torch.cat((x_mask, mask), dim=1)
            output_cn = model_cn(input_cn)
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd(
                (input_ld_fake.to(gpu), input_gd_fake.to(gpu)))
            loss_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = gen_hole_area(size=(args.ld_input_size,
                                                 args.ld_input_size),
                                           mask_size=(x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)
            output_real = model_cd((input_ld_real, input_gd_real))
            loss_real = bceloss(output_real, real)

            # reduce
            loss = (loss_fake + loss_real) / 2.

            # backward
            loss.backward()
            cnt_bdivs += 1

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                # optimize
                opt_cd.step()
                # clear grads
                opt_cd.zero_grad()
                # update progbar
                pbar.set_description('phase 2 | train loss: %.5f' % loss.cpu())
                pbar.update()
                # test
                if pbar.n % args.snaperiod_2 == 0:
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        mask = gen_input_mask(
                            shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                            hole_size=((args.hole_min_w, args.hole_max_w),
                                       (args.hole_min_h, args.hole_max_h)),
                            hole_area=gen_hole_area(
                                (args.ld_input_size, args.ld_input_size),
                                (x.shape[3], x.shape[2])),
                            max_holes=args.max_holes,
                        ).to(gpu)
                        x_mask = x - x * mask + mpv * mask
                        input = torch.cat((x_mask, mask), dim=1)
                        output = model_cn(input)
                        completed = poisson_blend(x, output, mask)
                        imgs = torch.cat(
                            (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
                        imgpath = os.path.join(args.result_dir, 'phase_2',
                                               'step%d.png' % pbar.n)
                        model_cd_path = os.path.join(
                            args.result_dir, 'phase_2',
                            'model_cd_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        if args.data_parallel:
                            torch.save(model_cd.module.state_dict(),
                                       model_cd_path)
                        else:
                            torch.save(model_cd.state_dict(), model_cd_path)
                # terminate
                if pbar.n >= args.steps_2:
                    break
    pbar.close()

    # ================================================
    # Training Phase 3
    # ================================================
    # training
    cnt_bdivs = 0
    pbar = tqdm(total=args.steps_3)
    while pbar.n < args.steps_3:
        for x in train_loader:

            # forward model_cd
            x = x.to(gpu)
            hole_area_fake = gen_hole_area(
                (args.ld_input_size, args.ld_input_size),
                (x.shape[3], x.shape[2]))
            mask = gen_input_mask(
                shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                hole_size=((args.hole_min_w, args.hole_max_w),
                           (args.hole_min_h, args.hole_max_h)),
                hole_area=hole_area_fake,
                max_holes=args.max_holes,
            ).to(gpu)

            # fake forward
            fake = torch.zeros((len(x), 1)).to(gpu)
            x_mask = x - x * mask + mpv * mask
            input_cn = torch.cat((x_mask, mask), dim=1)
            output_cn = model_cn(input_cn)
            input_gd_fake = output_cn.detach()
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, input_gd_fake))
            loss_cd_fake = bceloss(output_fake, fake)

            # real forward
            hole_area_real = gen_hole_area(size=(args.ld_input_size,
                                                 args.ld_input_size),
                                           mask_size=(x.shape[3], x.shape[2]))
            real = torch.ones((len(x), 1)).to(gpu)
            input_gd_real = x
            input_ld_real = crop(input_gd_real, hole_area_real)
            output_real = model_cd((input_ld_real, input_gd_real))
            loss_cd_real = bceloss(output_real, real)

            # reduce
            loss_cd = (loss_cd_fake + loss_cd_real) * alpha / 2.

            # backward model_cd
            loss_cd.backward()

            cnt_bdivs += 1
            if cnt_bdivs >= args.bdivs:
                # optimize
                opt_cd.step()
                # clear grads
                opt_cd.zero_grad()

            # forward model_cn
            loss_cn_1 = completion_network_loss(x, output_cn, mask)
            input_gd_fake = output_cn
            input_ld_fake = crop(input_gd_fake, hole_area_fake)
            output_fake = model_cd((input_ld_fake, input_gd_fake))
            loss_cn_2 = bceloss(output_fake, real)

            # reduce
            loss_cn = (loss_cn_1 + alpha * loss_cn_2) / 2.

            # backward model_cn
            loss_cn.backward()

            if cnt_bdivs >= args.bdivs:
                cnt_bdivs = 0
                # optimize
                opt_cn.step()
                # clear grads
                opt_cn.zero_grad()
                # update progbar
                pbar.set_description(
                    'phase 3 | train loss (cd): %.5f (cn): %.5f' %
                    (loss_cd.cpu(), loss_cn.cpu()))
                pbar.update()
                # test
                if pbar.n % args.snaperiod_3 == 0:
                    with torch.no_grad():
                        x = sample_random_batch(
                            test_dset,
                            batch_size=args.num_test_completions).to(gpu)
                        mask = gen_input_mask(
                            shape=(x.shape[0], 1, x.shape[2], x.shape[3]),
                            hole_size=((args.hole_min_w, args.hole_max_w),
                                       (args.hole_min_h, args.hole_max_h)),
                            hole_area=gen_hole_area(
                                (args.ld_input_size, args.ld_input_size),
                                (x.shape[3], x.shape[2])),
                            max_holes=args.max_holes,
                        ).to(gpu)
                        x_mask = x - x * mask + mpv * mask
                        input = torch.cat((x_mask, mask), dim=1)
                        output = model_cn(input)
                        completed = poisson_blend(x, output, mask)
                        imgs = torch.cat(
                            (x.cpu(), x_mask.cpu(), completed.cpu()), dim=0)
                        imgpath = os.path.join(args.result_dir, 'phase_3',
                                               'step%d.png' % pbar.n)
                        model_cn_path = os.path.join(
                            args.result_dir, 'phase_3',
                            'model_cn_step%d' % pbar.n)
                        model_cd_path = os.path.join(
                            args.result_dir, 'phase_3',
                            'model_cd_step%d' % pbar.n)
                        save_image(imgs, imgpath, nrow=len(x))
                        if args.data_parallel:
                            torch.save(model_cn.module.state_dict(),
                                       model_cn_path)
                            torch.save(model_cd.module.state_dict(),
                                       model_cd_path)
                        else:
                            torch.save(model_cn.state_dict(), model_cn_path)
                            torch.save(model_cd.state_dict(), model_cd_path)
                # terminate
                if pbar.n >= args.steps_3:
                    break
    pbar.close()
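All three phases above use the same gradient-accumulation idiom: backward() runs on every mini-batch, but the optimizer only steps once args.bdivs batches have accumulated, which emulates a batch size of bsize while loading only bsize // bdivs images at a time. A minimal sketch of that idiom in isolation (the toy model, data, and bdivs value are placeholders, not part of the script above):

import torch

model = torch.nn.Linear(10, 1)
opt = torch.optim.Adam(model.parameters())
loader = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(8)]
bdivs = 4                                     # accumulate gradients over 4 mini-batches

cnt_bdivs = 0
for x, y in loader:
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()                           # gradients accumulate across iterations
    cnt_bdivs += 1
    if cnt_bdivs >= bdivs:
        cnt_bdivs = 0
        opt.step()                            # one optimizer step per bdivs batches
        opt.zero_grad()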
Example #5
    def train(self, device, args):
        # data_dir, dataset_metadata, save_dir,
        # 	  batch_size=64, num_train_epochs=300,
        # 	  device='cuda', init_lr=5e-4, logging_steps=50,
        # 	  label_key='edema_severity', loss_method='CrossEntropyLoss'):
        '''
        Create a logger for logging model training
        '''
        logger = logging.getLogger(__name__)
        '''
        Create an instance of training data loader
        '''
        print('***** Instantiate a data loader *****')
        dataset = build_training_dataset(
            data_dir=args.data_dir,
            img_size=self.img_size,
            dataset_metadata=args.dataset_metadata,
            label_key=args.label_key)
        data_loader = DataLoader(dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=8,
                                 pin_memory=True)
        print(f'Total number of training images: {len(dataset)}')
        ''' 
        Create an instance of loss
        '''
        print('***** Instantiate the training loss *****')
        if args.loss_method == 'CrossEntropyLoss':
            loss_criterion = CrossEntropyLoss().to(device)
        elif args.loss_method == 'BCEWithLogitsLoss':
            loss_criterion = BCEWithLogitsLoss().to(device)
        elif args.loss_method == 'BCELoss':
            loss_criterion = BCELoss().to(device)
        '''
        Create an instance of optimizer and learning rate scheduler
        '''
        print('***** Instantiate an optimizer *****')
        optimizer = optim.Adam(self.model.parameters(), lr=args.init_lr)
        '''
        Train the model
        '''
        print('***** Train the model *****')
        self.model = self.model.to(device)
        self.model.train()
        train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
        for epoch in train_iterator:
            start_time = time.time()
            epoch_loss = 0
            epoch_iterator = tqdm(data_loader, desc="Iteration")
            for i, batch in enumerate(epoch_iterator, 0):
                # Parse the batch
                images, labels, image_ids = batch
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                # Zero out the parameter gradients
                optimizer.zero_grad()

                # Forward + backward + optimize
                outputs = self.model(images)
                pred_logits = outputs[-1]
                if args.loss_method == 'BCEWithLogitsLoss':
                    labels = torch.reshape(labels, pred_logits.size())

                # pred_logits[labels<0] = 0
                # labels[labels<0] = 0.5
                #labels = labels.type_as(outputs)
                loss = loss_criterion(pred_logits.float(), labels.float())

                loss.backward()
                optimizer.step()

                # Record training statistics
                epoch_loss += loss.item()
                print("Epoch loss %s" % str(epoch_loss))

                if not loss.item() > 0:
                    logger.info(f"loss: {loss.item()}")
                    logger.info(f"pred_logits: {pred_logits}")
                    logger.info(f"labels: {labels}")
            self.model.save_pretrained(args.save_dir, epoch=epoch + 1)
            interval = time.time() - start_time

            print(f'Epoch {epoch+1} finished! Epoch loss: {epoch_loss:.5f}')

            logger.info(f"  Epoch {epoch+1} loss = {epoch_loss:.5f}")
            logger.info(f"  Epoch {epoch+1} took {interval:.3f} s")

        return
Example #6
            Sigmoid(),
        )
        normal_init(self)

    def forward(self, input, label):
        x = torch.cat([input, label], 1)
        x = self.conv(x)
        x = x.view(-1, 128 * 7 * 7)
        x = self.fc(x)

        return x


NetG = Generator()
NetD = Discriminator()
BCE_LOSS = BCELoss()
G_optimizer = Adam(NetG.parameters(),
                   lr=CONFIG["LEARNING_RATE"],
                   betas=(0.5, 0.999))
D_optimizer = Adam(NetD.parameters(),
                   lr=CONFIG["LEARNING_RATE"],
                   betas=(0.5, 0.999))

if CONFIG["GPU_NUMS"] > 0:
    NetG = NetG.cuda()
    NetD = NetD.cuda()
    BCE_LOSS = BCE_LOSS.cuda()

transform = Compose([ToTensor()])
train_loader = torch.utils.data.DataLoader(MNISTDataSetForPytorch(
    train=True, transform=transform),
Example #7
def augmented_cross_entropy(s_hat, log_r_hat, t0_hat, t1_hat, y_true, r_true,
                            t0_true, t1_true):
    s_hat = 1.0 / (1.0 + torch.exp(log_r_hat))
    s_true = 1.0 / (1.0 + r_true)

    return BCELoss()(s_hat, s_true)
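BCELoss requires both its input and its target to lie in [0, 1], which is why Example #7 maps the log likelihood ratio through 1 / (1 + exp(log_r)) before taking the loss. A small sketch of the same transform (the tensor names and shapes are illustrative):

import torch
from torch.nn import BCELoss

log_r_hat = torch.randn(16, 1)                # estimated log likelihood ratio
r_true = torch.rand(16, 1) * 5.0              # true ratio, non-negative
s_hat = torch.sigmoid(-log_r_hat)             # identical to 1 / (1 + exp(log_r_hat))
s_true = 1.0 / (1.0 + r_true)
loss = BCELoss()(s_hat, s_true)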
Example #8
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Examples::
        from transformers import BertTokenizer, BertForSequenceClassification
        import torch
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]
        """

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        pooled_output = outputs[1]        # use [CLS] pooled output

        pooled_output = self.dropout(pooled_output)
        logits = self.regressor(pooled_output)
        logits = self.sigmoid(logits)
        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss_fct = BCELoss()
            labels = labels.to(torch.float)
            loss = loss_fct(logits.view(-1), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs
Example #9
def main(cfg):
    """Runs main training procedure."""

    # fix random seeds for reproducibility
    seed_everything(seed=cfg['seed'])

    # neptune logging
    neptune.init(project_qualified_name=cfg['neptune_project_name'],
                 api_token=cfg['neptune_api_token'])

    neptune.create_experiment(name=cfg['neptune_experiment'], params=cfg)

    print('Preparing model and data...')
    print('Using SMP version:', smp.__version__)

    num_classes = 1 if len(cfg['classes']) == 1 else (len(cfg['classes']) + 1)
    activation = 'sigmoid' if num_classes == 1 else 'softmax2d'
    background = not cfg['ignore_channels']
    binary = num_classes == 1
    softmax = num_classes != 1
    sigmoid = num_classes == 1

    aux_params = dict(
        pooling=cfg['pooling'],  # one of 'avg', 'max'
        dropout=cfg['dropout'],  # dropout ratio, default is None
        activation='sigmoid',  # activation function, default is None
        classes=num_classes)  # define number of output labels

    # configure model
    models = {
        'unet':
        Unet(encoder_name=cfg['encoder_name'],
             encoder_weights=cfg['encoder_weights'],
             decoder_use_batchnorm=cfg['use_batchnorm'],
             classes=num_classes,
             activation=activation,
             aux_params=aux_params),
        'pspnet':
        PSPNet(encoder_name=cfg['encoder_name'],
               encoder_weights=cfg['encoder_weights'],
               classes=num_classes,
               activation=activation,
               aux_params=aux_params),
        'pan':
        PAN(encoder_name=cfg['encoder_name'],
            encoder_weights=cfg['encoder_weights'],
            classes=num_classes,
            activation=activation,
            aux_params=aux_params),
        'deeplabv3plus':
        DeepLabV3Plus(encoder_name=cfg['encoder_name'],
                      encoder_weights=cfg['encoder_weights'],
                      classes=num_classes,
                      activation=activation,
                      aux_params=aux_params)
    }

    assert cfg['architecture'] in models.keys()
    model = models[cfg['architecture']]

    # configure loss
    losses = {
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=softmax,
                 sigmoid=sigmoid,
                 batch=cfg['combine']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=softmax,
                            sigmoid=sigmoid,
                            batch=cfg['combine'])
    }

    assert cfg['loss'] in losses.keys()
    loss = losses[cfg['loss']]

    # configure optimizer
    optimizers = {
        'adam': Adam([dict(params=model.parameters(), lr=cfg['lr'])]),
        'adamw': AdamW([dict(params=model.parameters(), lr=cfg['lr'])]),
        'rmsprop': RMSprop([dict(params=model.parameters(), lr=cfg['lr'])])
    }

    assert cfg['optimizer'] in optimizers.keys()
    optimizer = optimizers[cfg['optimizer']]

    # configure metrics
    metrics = {
        'dice_score':
        DiceMetric(include_background=background, reduction='mean'),
        'dice_smp':
        Fscore(threshold=cfg['rounding'],
               ignore_channels=cfg['ignore_channels']),
        'iou_smp':
        IoU(threshold=cfg['rounding'], ignore_channels=cfg['ignore_channels']),
        'generalized_dice':
        GeneralizedDiceLoss(include_background=background,
                            softmax=softmax,
                            sigmoid=sigmoid,
                            batch=cfg['combine']),
        'dice_loss':
        DiceLoss(include_background=background,
                 softmax=softmax,
                 sigmoid=sigmoid,
                 batch=cfg['combine']),
        'cross_entropy':
        BCELoss(reduction='mean'),
        'accuracy':
        Accuracy(ignore_channels=cfg['ignore_channels'])
    }

    assert all(m['name'] in metrics.keys() for m in cfg['metrics'])
    metrics = [(metrics[m['name']], m['name'], m['type'])
               for m in cfg['metrics']]  # tuple of (metric, name, type)

    # TODO: Fix metric names

    # configure scheduler
    schedulers = {
        'steplr':
        StepLR(optimizer, step_size=cfg['step_size'], gamma=0.5),
        'cosine':
        CosineAnnealingLR(optimizer,
                          cfg['epochs'],
                          eta_min=cfg['eta_min'],
                          last_epoch=-1)
    }

    assert cfg['scheduler'] in schedulers.keys()
    scheduler = schedulers[cfg['scheduler']]

    # configure augmentations
    train_transform = load_train_transform(transform_type=cfg['transform'],
                                           patch_size=cfg['patch_size_train'])
    valid_transform = load_valid_transform(
        patch_size=cfg['patch_size_valid'])  # manually selected patch size

    train_dataset = ArtifactDataset(df_path=cfg['train_data'],
                                    classes=cfg['classes'],
                                    transform=train_transform,
                                    normalize=cfg['normalize'],
                                    ink_filters=cfg['ink_filters'])

    valid_dataset = ArtifactDataset(df_path=cfg['valid_data'],
                                    classes=cfg['classes'],
                                    transform=valid_transform,
                                    normalize=cfg['normalize'],
                                    ink_filters=cfg['ink_filters'])

    test_dataset = ArtifactDataset(df_path=cfg['test_data'],
                                   classes=cfg['classes'],
                                   transform=valid_transform,
                                   normalize=cfg['normalize'],
                                   ink_filters=cfg['ink_filters'])

    # load pre-sampled patch arrays
    train_image, train_mask = train_dataset[0]
    valid_image, valid_mask = valid_dataset[0]
    print('Shape of image patch', train_image.shape)
    print('Shape of mask patch', train_mask.shape)
    print('Train dataset shape:', len(train_dataset))
    print('Valid dataset shape:', len(valid_dataset))
    assert train_image.shape[1] == cfg[
        'patch_size_train'] and train_image.shape[2] == cfg['patch_size_train']
    assert valid_image.shape[1] == cfg[
        'patch_size_valid'] and valid_image.shape[2] == cfg['patch_size_valid']

    # save intermediate augmentations
    if cfg['eval_dir']:
        default_dataset = ArtifactDataset(df_path=cfg['train_data'],
                                          classes=cfg['classes'],
                                          transform=None,
                                          normalize=None,
                                          ink_filters=cfg['ink_filters'])

        transform_dataset = ArtifactDataset(df_path=cfg['train_data'],
                                            classes=cfg['classes'],
                                            transform=train_transform,
                                            normalize=None,
                                            ink_filters=cfg['ink_filters'])

        for idx in range(0, min(500, len(train_dataset)), 10):
            image_input, image_mask = default_dataset[idx]
            image_input = image_input.transpose((1, 2, 0)).astype(np.uint8)

            image_mask = image_mask.transpose(1, 2, 0)
            image_mask = np.argmax(
                image_mask, axis=2) if not binary else image_mask.squeeze()
            image_mask = image_mask.astype(np.uint8)

            image_transform, _ = transform_dataset[idx]
            image_transform = image_transform.transpose(
                (1, 2, 0)).astype(np.uint8)

            idx_str = str(idx).zfill(3)
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}a_image_input.png'),
                              image_input,
                              check_contrast=False)
            plt.imsave(os.path.join(cfg['eval_dir'],
                                    f'{idx_str}b_image_mask.png'),
                       image_mask,
                       vmin=0,
                       vmax=6,
                       cmap='Spectral')
            skimage.io.imsave(os.path.join(cfg['eval_dir'],
                                           f'{idx_str}c_image_transform.png'),
                              image_transform,
                              check_contrast=False)

        del transform_dataset

    # update process
    print('Starting training...')
    print('Available GPUs for training:', torch.cuda.device_count())

    # pytorch module wrapper
    class DataParallelModule(torch.nn.DataParallel):
        def __getattr__(self, name):
            try:
                return super().__getattr__(name)
            except AttributeError:
                return getattr(self.module, name)

    # data parallel training
    if torch.cuda.device_count() > 1:
        model = DataParallelModule(model)

    train_loader = DataLoader(train_dataset,
                              batch_size=cfg['batch_size'],
                              num_workers=cfg['workers'],
                              shuffle=True)

    valid_loader = DataLoader(valid_dataset,
                              batch_size=int(cfg['batch_size'] / 4),
                              num_workers=cfg['workers'],
                              shuffle=False)

    test_loader = DataLoader(test_dataset,
                             batch_size=int(cfg['batch_size'] / 4),
                             num_workers=cfg['workers'],
                             shuffle=False)

    trainer = Trainer(model=model,
                      device=cfg['device'],
                      save_checkpoints=cfg['save_checkpoints'],
                      checkpoint_dir=cfg['checkpoint_dir'],
                      checkpoint_name=cfg['checkpoint_name'])

    trainer.compile(optimizer=optimizer,
                    loss=loss,
                    metrics=metrics,
                    num_classes=num_classes)

    trainer.fit(train_loader,
                valid_loader,
                epochs=cfg['epochs'],
                scheduler=scheduler,
                verbose=cfg['verbose'],
                loss_weight=cfg['loss_weight'],
                test_loader=test_loader,
                binary=binary)

    # validation inference
    model.load_state_dict(
        torch.load(os.path.join(cfg['checkpoint_dir'],
                                cfg['checkpoint_name'])))
    model.to(cfg['device'])
    model.eval()

    # save best checkpoint to neptune
    neptune.log_artifact(
        os.path.join(cfg['checkpoint_dir'], cfg['checkpoint_name']))

    # setup directory to save plots
    if os.path.isdir(cfg['plot_dir_valid']):
        shutil.rmtree(cfg['plot_dir_valid'])
    os.makedirs(cfg['plot_dir_valid'], exist_ok=True)

    # valid dataset without transformations and normalization for image visualization
    valid_dataset_vis = ArtifactDataset(df_path=cfg['valid_data'],
                                        classes=cfg['classes'],
                                        ink_filters=cfg['ink_filters'])

    # keep track of valid masks
    valid_preds = []
    valid_masks = []

    if cfg['save_checkpoints']:
        print('Predicting valid patches...')
        for n in range(len(valid_dataset)):
            image_vis = valid_dataset_vis[n][0].astype('uint8')
            image_vis = image_vis.transpose(1, 2, 0)
            image, gt_mask = valid_dataset[n]
            gt_mask = gt_mask.transpose(1, 2, 0)
            gt_mask = np.argmax(gt_mask,
                                axis=2) if not binary else gt_mask.squeeze()
            gt_mask = gt_mask.astype(np.uint8)
            valid_masks.append(gt_mask)

            x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0)
            pr_mask, _ = model.predict(x_tensor)
            pr_mask = pr_mask.squeeze(axis=0).cpu().numpy().round()
            pr_mask = pr_mask.transpose(1, 2, 0)
            pr_mask = np.argmax(pr_mask,
                                axis=2) if not binary else pr_mask.squeeze()
            pr_mask = pr_mask.astype(np.uint8)
            valid_preds.append(pr_mask)

            save_predictions(out_path=cfg['plot_dir_valid'],
                             index=n + 1,
                             image=image_vis,
                             ground_truth_mask=gt_mask,
                             predicted_mask=pr_mask)

    del train_dataset, valid_dataset
    del train_loader, valid_loader

    # calculate dice per class
    valid_masks = np.stack(valid_masks, axis=0)
    valid_masks = valid_masks.flatten()
    valid_preds = np.stack(valid_preds, axis=0)
    valid_preds = valid_preds.flatten()
    dice_score = f1_score(y_true=valid_masks, y_pred=valid_preds, average=None)
    neptune.log_text('valid_dice_class', str(dice_score))
    print('Valid dice score (class):', str(dice_score))

    if cfg['evaluate_test_set']:
        print('Predicting test patches...')

        # setup directory to save plots
        if os.path.isdir(cfg['plot_dir_test']):
            shutil.rmtree(cfg['plot_dir_test'])
        os.makedirs(cfg['plot_dir_test'], exist_ok=True)

        # test dataset without transformations and normalization for image visualization
        test_dataset_vis = ArtifactDataset(df_path=cfg['test_data'],
                                           classes=cfg['classes'],
                                           ink_filters=cfg['ink_filters'])

        # keep track of test masks
        test_masks = []
        test_preds = []

        for n in range(len(test_dataset)):
            image_vis = test_dataset_vis[n][0].astype('uint8')
            image_vis = image_vis.transpose(1, 2, 0)
            image, gt_mask = test_dataset[n]
            gt_mask = gt_mask.transpose(1, 2, 0)
            gt_mask = np.argmax(gt_mask,
                                axis=2) if not binary else gt_mask.squeeze()
            gt_mask = gt_mask.astype(np.uint8)
            test_masks.append(gt_mask)

            x_tensor = torch.from_numpy(image).to(cfg['device']).unsqueeze(0)
            pr_mask, _ = model.predict(x_tensor)
            pr_mask = pr_mask.squeeze(axis=0).cpu().numpy().round()
            pr_mask = pr_mask.transpose(1, 2, 0)
            pr_mask = np.argmax(pr_mask,
                                axis=2) if not binary else pr_mask.squeeze()
            pr_mask = pr_mask.astype(np.uint8)
            test_preds.append(pr_mask)

            save_predictions(out_path=cfg['plot_dir_test'],
                             index=n + 1,
                             image=image_vis,
                             ground_truth_mask=gt_mask,
                             predicted_mask=pr_mask)

        # calculate dice per class (outside the per-patch loop, as in the validation block)
        test_masks = np.stack(test_masks, axis=0)
        test_masks = test_masks.flatten()
        test_preds = np.stack(test_preds, axis=0)
        test_preds = test_preds.flatten()
        dice_score = f1_score(y_true=test_masks,
                              y_pred=test_preds,
                              average=None)
        neptune.log_text('test_dice_class', str(dice_score))
        print('Test dice score (class):', str(dice_score))

    # end of training process
    print('Finished training!')
Example #10
    def _define_loss(self):
        self.loss_functions = {"Generator": BCELoss(), "Adversariat": BCELoss(), "L1": L1Loss()}
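Example #10 only registers the loss functions; the training loop that consumes them is not shown. As a hedged sketch of how such a dictionary is typically used in a GAN update (the discriminator outputs and label tensors below are assumptions, not taken from that repository):

import torch
from torch.nn import BCELoss, L1Loss

loss_functions = {"Generator": BCELoss(), "Adversariat": BCELoss(), "L1": L1Loss()}

d_out_real = torch.rand(8, 1)                 # discriminator scores on real images, in [0, 1]
d_out_fake = torch.rand(8, 1)                 # discriminator scores on generated images
real = torch.ones(8, 1)
fake = torch.zeros(8, 1)

d_loss = (loss_functions["Adversariat"](d_out_real, real)
          + loss_functions["Adversariat"](d_out_fake, fake))
g_loss = loss_functions["Generator"](d_out_fake, real)   # generator wants fakes scored as real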
Example #11
def main(args):
    """
    :param n_trials: int specifying number of random train/test splits to use
    :param test_set_size: float in range [0, 1] specifying fraction of dataset to use as test set
    """

    df = pd.read_csv('data/covid_multitask_pIC50.smi')
    smiles_list = df['SMILES'].values
    y = df[[
        'acry_class', 'chloro_class', 'rest_class', 'acry_reg', 'chloro_reg',
        'rest_reg'
    ]].to_numpy()

    n_tasks = y.shape[1]
    class_inds = [0, 1, 2]
    reg_inds = [3, 4, 5]
    X = [Chem.MolFromSmiles(m) for m in smiles_list]
    descs = np.array([generate_descriptors(m) for m in smiles_list])
    # Initialise featurisers
    atom_featurizer = CanonicalAtomFeaturizer()
    bond_featurizer = CanonicalBondFeaturizer()

    e_feats = bond_featurizer.feat_size('e')
    n_feats = atom_featurizer.feat_size('h')
    print('Number of features: ', n_feats)

    X = np.array([
        mol_to_bigraph(m,
                       node_featurizer=atom_featurizer,
                       edge_featurizer=bond_featurizer) for m in X
    ])

    r2_list = []
    rmse_list = []
    roc_list = []
    prc_list = []

    for i in range(args.n_trials):
        writer = SummaryWriter('runs/' + args.savename + '/run_' + str(i))

        #X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=args.test_set_size, random_state=i+5)
        if args.test:
            X_train_acry, X_test_acry, \
            descs_train_acry, descs_test_acry, \
            y_train_acry, y_test_acry = train_test_split(X[~np.isnan(y[:,0])], descs[~np.isnan(y[:, 0])],
                                                         y[~np.isnan(y[:,0])], stratify=y[:,0][~np.isnan(y[:,0])],
                                                         test_size=args.test_set_size, shuffle=True, random_state=i+5)
            X_train_chloro, X_test_chloro, \
            descs_train_chloro, descs_test_chloro, \
            y_train_chloro, y_test_chloro = train_test_split(X[~np.isnan(y[:,1])], descs[~np.isnan(y[:, 1])],
                                                             y[~np.isnan(y[:,1])], stratify=y[:,1][~np.isnan(y[:,1])],
                                                              test_size=args.test_set_size, shuffle=True, random_state=i+5)
            X_train_rest, X_test_rest, \
            descs_train_rest, descs_test_rest, \
            y_train_rest, y_test_rest = train_test_split(X[~np.isnan(y[:,2])], descs[~np.isnan(y[:, 2])],
                                                         y[~np.isnan(y[:,2])], stratify=y[:,2][~np.isnan(y[:,2])],
                                                          test_size=args.test_set_size, shuffle=True, random_state=i+5)

            X_train = np.concatenate(
                [X_train_acry, X_train_chloro, X_train_rest])
            X_test = np.concatenate([X_test_acry, X_test_chloro, X_test_rest])
            desc_train = np.concatenate(
                [descs_train_acry, descs_train_chloro, descs_train_rest])
            desc_test = np.concatenate(
                [descs_test_acry, descs_test_chloro, descs_test_rest])
            y_train = np.concatenate(
                [y_train_acry, y_train_chloro, y_train_rest])
            y_test = np.concatenate([y_test_acry, y_test_chloro, y_test_rest])

            desc_train = torch.Tensor(desc_train)
            desc_test = torch.Tensor(desc_test)
            y_train = torch.Tensor(y_train)
            y_test = torch.Tensor(y_test)

            train_data = list(zip(X_train, desc_train, y_train))
            test_data = list(zip(X_test, desc_test, y_test))

            train_loader = DataLoader(train_data,
                                      batch_size=32,
                                      shuffle=True,
                                      collate_fn=collate,
                                      drop_last=False)
            test_loader = DataLoader(test_data,
                                     batch_size=32,
                                     shuffle=True,
                                     collate_fn=collate,
                                     drop_last=False)
        else:
            train_data = list(zip(X, descs, y))
            train_loader = DataLoader(train_data,
                                      batch_size=32,
                                      shuffle=True,
                                      collate_fn=collate,
                                      drop_last=False)

        process = Net(class_inds, reg_inds)
        process = process.to(device)

        mpnn_net = CustomMPNNPredictor(node_in_feats=n_feats,
                                       edge_in_feats=e_feats,
                                       node_out_feats=128,
                                       n_tasks=n_tasks)
        mpnn_net = mpnn_net.to(device)

        reg_loss_fn = MSELoss()
        class_loss_fn = BCELoss()

        optimizer = torch.optim.Adam(mpnn_net.parameters(), lr=0.001)

        for epoch in range(1, args.n_epochs + 1):
            epoch_loss = 0
            preds = []
            labs = []
            mpnn_net.train()
            n = 0
            for i, (bg, dcs, labels) in enumerate(train_loader):
                dcs = dcs.to(device)
                labels = labels.to(device)
                atom_feats = bg.ndata.pop('h').to(device)
                bond_feats = bg.edata.pop('e').to(device)
                y_pred = mpnn_net(bg, atom_feats, bond_feats, dcs)
                y_pred = process(y_pred)
                if args.debug:
                    print('y_pred: {}'.format(y_pred))
                    print('label: {}'.format(labels))
                loss = torch.tensor(0.0, device=device)
                for ind in reg_inds:
                    if len(labels[:, ind][~torch.isnan(labels[:, ind])]) == 0:
                        continue
                    loss = loss + reg_loss_fn(
                        y_pred[:, ind][~torch.isnan(labels[:, ind])],
                        labels[:, ind][~torch.isnan(labels[:, ind])])
                if args.debug:
                    print('reg loss: {}'.format(loss))

                for ind in class_inds:
                    if len(labels[:, ind][~torch.isnan(labels[:, ind])]) == 0:
                        continue
                    loss = loss + class_loss_fn(
                        y_pred[:, ind][~torch.isnan(labels[:, ind])],
                        labels[:, ind][~torch.isnan(labels[:, ind])])

                optimizer.zero_grad()
                loss.backward()
                if args.debug:
                    print('class + reg loss: {}'.format(loss))
                    print('y_pred: {}'.format(y_pred))
                    print('label: {}'.format(labels))
                # for p in mpnn_net.parameters():
                #     print(p.grad)
                optimizer.step()
                if args.debug:
                    n += 1
                    if n == 10:
                        raise Exception
                epoch_loss += loss.detach().item()

                labels = labels.cpu().numpy()
                y_pred = y_pred.detach().cpu().numpy()

                # store labels and preds
                preds.append(y_pred)
                labs.append(labels)

            labs = np.concatenate(labs, axis=0)
            preds = np.concatenate(preds, axis=0)
            rmses = []
            r2s = []
            rocs = []
            prcs = []
            for ind in reg_inds:
                rmse = np.sqrt(
                    mean_squared_error(labs[:, ind][~np.isnan(labs[:, ind])],
                                       preds[:, ind][~np.isnan(labs[:, ind])]))
                r2 = r2_score(labs[:, ind][~np.isnan(labs[:, ind])],
                              preds[:, ind][~np.isnan(labs[:, ind])])
                rmses.append(rmse)
                r2s.append(r2)

            for ind in class_inds:
                roc = roc_auc_score(labs[:, ind][~np.isnan(labs[:, ind])],
                                    preds[:, ind][~np.isnan(labs[:, ind])])
                precision, recall, thresholds = precision_recall_curve(
                    labs[:, ind][~np.isnan(labs[:, ind])],
                    preds[:, ind][~np.isnan(labs[:, ind])])
                prc = auc(recall, precision)
                rocs.append(roc)
                prcs.append(prc)

            writer.add_scalar('LOSS/train', epoch_loss, epoch)
            writer.add_scalar('train/acry_rocauc', rocs[0], epoch)
            writer.add_scalar('train/acry_prcauc', prcs[0], epoch)
            writer.add_scalar('train/chloro_rocauc', rocs[1], epoch)
            writer.add_scalar('train/chloro_prcauc', prcs[1], epoch)
            writer.add_scalar('train/rest_rocauc', rocs[2], epoch)
            writer.add_scalar('train/rest_prcauc', prcs[2], epoch)

            writer.add_scalar('train/acry_rmse', rmses[0], epoch)
            writer.add_scalar('train/acry_r2', r2s[0], epoch)
            writer.add_scalar('train/chloro_rmse', rmses[1], epoch)
            writer.add_scalar('train/chloro_r2', r2s[1], epoch)
            writer.add_scalar('train/rest_rmse', rmses[2], epoch)
            writer.add_scalar('train/rest_r2', r2s[2], epoch)

            if epoch % 20 == 0:
                print(f"\nepoch: {epoch}, "
                      f"LOSS: {epoch_loss:.3f}"
                      f"\n acry ROC-AUC: {rocs[0]:.3f}, "
                      f"acry PRC-AUC: {prcs[0]:.3f}"
                      f"\n chloro ROC-AUC: {rocs[1]:.3f}, "
                      f"chloro PRC-AUC: {prcs[1]:.3f}"
                      f"\n rest ROC-AUC: {rocs[2]:.3f}, "
                      f"rest PRC-AUC: {prcs[2]:.3f}"
                      f"\n acry R2: {r2s[0]:.3f}, "
                      f"acry RMSE: {rmses[0]:.3f}"
                      f"\n chloro R2: {r2s[1]:.3f}, "
                      f"chloro RMSE: {rmses[1]:.3f}"
                      f"\n rest R2: {r2s[2]:.3f}, "
                      f"rest RMSE: {rmses[2]:.3f}")
                try:
                    torch.save(
                        mpnn_net.state_dict(), 'models/' + args.savename +
                        '/model_epoch_' + str(epoch) + '.pt')
                except FileNotFoundError:
                    os.makedirs('models/' + args.savename, exist_ok=True)
                    torch.save(
                        mpnn_net.state_dict(), 'models/' + args.savename +
                        '/model_epoch_' + str(epoch) + '.pt')
            if args.test:
                # Evaluate
                mpnn_net.eval()
                preds = []
                labs = []
                for i, (bg, dcs, labels) in enumerate(test_loader):
                    dcs = dcs.to(device)
                    labels = labels.to(device)
                    atom_feats = bg.ndata.pop('h').to(device)
                    bond_feats = bg.edata.pop('e').to(device)
                    y_pred = mpnn_net(bg, atom_feats, bond_feats, dcs)
                    y_pred = process(y_pred)

                    labels = labels.cpu().numpy()
                    y_pred = y_pred.detach().cpu().numpy()

                    preds.append(y_pred)
                    labs.append(labels)

                labs = np.concatenate(labs, axis=0)
                preds = np.concatenate(preds, axis=0)
                rmses = []
                r2s = []
                rocs = []
                prcs = []
                for ind in reg_inds:

                    rmse = np.sqrt(
                        mean_squared_error(
                            labs[:, ind][~np.isnan(labs[:, ind])],
                            preds[:, ind][~np.isnan(labs[:, ind])]))

                    r2 = r2_score(labs[:, ind][~np.isnan(labs[:, ind])],
                                  preds[:, ind][~np.isnan(labs[:, ind])])
                    rmses.append(rmse)
                    r2s.append(r2)
                for ind in class_inds:
                    roc = roc_auc_score(labs[:, ind][~np.isnan(labs[:, ind])],
                                        preds[:, ind][~np.isnan(labs[:, ind])])
                    precision, recall, thresholds = precision_recall_curve(
                        labs[:, ind][~np.isnan(labs[:, ind])],
                        preds[:, ind][~np.isnan(labs[:, ind])])
                    prc = auc(recall, precision)
                    rocs.append(roc)
                    prcs.append(prc)
                writer.add_scalar('test/acry_rocauc', rocs[0], epoch)
                writer.add_scalar('test/acry_prcauc', prcs[0], epoch)
                writer.add_scalar('test/chloro_rocauc', rocs[1], epoch)
                writer.add_scalar('test/chloro_prcauc', prcs[1], epoch)
                writer.add_scalar('test/rest_rocauc', rocs[2], epoch)
                writer.add_scalar('test/rest_prcauc', prcs[2], epoch)

                writer.add_scalar('test/acry_rmse', rmses[0], epoch)
                writer.add_scalar('test/acry_r2', r2s[0], epoch)
                writer.add_scalar('test/chloro_rmse', rmses[1], epoch)
                writer.add_scalar('test/chloro_r2', r2s[1], epoch)
                writer.add_scalar('test/rest_rmse', rmses[2], epoch)
                writer.add_scalar('test/rest_r2', r2s[2], epoch)
                if epoch == (args.n_epochs):
                    print(
                        f"\n======================== TEST ========================"
                        f"\n acry ROC-AUC: {rocs[0]:.3f}, "
                        f"acry PRC-AUC: {prcs[0]:.3f}"
                        f"\n chloro ROC-AUC: {rocs[1]:.3f}, "
                        f"chloro PRC-AUC: {prcs[1]:.3f}"
                        f"\n rest ROC-AUC: {rocs[2]:.3f}, "
                        f"rest PRC-AUC: {prcs[2]:.3f}"
                        f"\n acry R2: {r2s[0]:.3f}, "
                        f"acry RMSE: {rmses[0]:.3f}"
                        f"\n chloro R2: {r2s[1]:.3f}, "
                        f"chloro RMSE: {rmses[1]:.3f}"
                        f"\n rest R2: {r2s[2]:.3f}, "
                        f"rest RMSE: {rmses[2]:.3f}")
                    roc_list.append(rocs)
                    prc_list.append(prcs)
                    r2_list.append(r2s)
                    rmse_list.append(rmses)

    roc_list = np.array(roc_list).T
    prc_list = np.array(prc_list).T
    r2_list = np.array(r2_list).T
    rmse_list = np.array(rmse_list).T
    print("\n ACRY")
    print("R^2: {:.4f} +- {:.4f}".format(
        np.mean(r2_list[0]),
        np.std(r2_list[0]) / np.sqrt(len(r2_list[0]))))
    print("RMSE: {:.4f} +- {:.4f}".format(
        np.mean(rmse_list[0]),
        np.std(rmse_list[0]) / np.sqrt(len(rmse_list[0]))))
    print("ROC-AUC: {:.3f} +- {:.3f}".format(
        np.mean(roc_list[0]),
        np.std(roc_list[0]) / np.sqrt(len(roc_list[0]))))
    print("PRC-AUC: {:.3f} +- {:.3f}".format(
        np.mean(prc_list[0]),
        np.std(prc_list[0]) / np.sqrt(len(prc_list[0]))))
    print("\n CHLORO")
    print("R^2: {:.4f} +- {:.4f}".format(
        np.mean(r2_list[1]),
        np.std(r2_list[1]) / np.sqrt(len(r2_list[1]))))
    print("RMSE: {:.4f} +- {:.4f}".format(
        np.mean(rmse_list[1]),
        np.std(rmse_list[1]) / np.sqrt(len(rmse_list[1]))))
    print("ROC-AUC: {:.3f} +- {:.3f}".format(
        np.mean(roc_list[1]),
        np.std(roc_list[1]) / np.sqrt(len(roc_list[1]))))
    print("PRC-AUC: {:.3f} +- {:.3f}".format(
        np.mean(prc_list[1]),
        np.std(prc_list[1]) / np.sqrt(len(prc_list[1]))))
    print("\n REST")
    print("R^2: {:.4f} +- {:.4f}".format(
        np.mean(r2_list[2]),
        np.std(r2_list[2]) / np.sqrt(len(r2_list[2]))))
    print("RMSE: {:.4f} +- {:.4f}".format(
        np.mean(rmse_list[2]),
        np.std(rmse_list[2]) / np.sqrt(len(rmse_list[2]))))
    print("ROC-AUC: {:.3f} +- {:.3f}".format(
        np.mean(roc_list[2]),
        np.std(roc_list[2]) / np.sqrt(len(roc_list[2]))))
    print("PRC-AUC: {:.3f} +- {:.3f}".format(
        np.mean(prc_list[2]),
        np.std(prc_list[2]) / np.sqrt(len(prc_list[2]))))
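# Aside (illustrative sketch, not part of the example above): the pattern of
# filtering out NaN labels per task before computing RMSE / R^2 / ROC-AUC /
# PRC-AUC can be factored into a small helper. The function name
# masked_task_metrics and the argument names are assumptions.
import numpy as np
from sklearn.metrics import (auc, mean_squared_error,
                             precision_recall_curve, r2_score, roc_auc_score)


def masked_task_metrics(labs, preds, reg_inds, class_inds):
    """Per-task metrics over (n_samples, n_tasks) arrays, ignoring NaN labels."""
    results = {"rmse": [], "r2": [], "roc": [], "prc": []}
    for ind in reg_inds:
        mask = ~np.isnan(labs[:, ind])
        y, p = labs[:, ind][mask], preds[:, ind][mask]
        results["rmse"].append(np.sqrt(mean_squared_error(y, p)))
        results["r2"].append(r2_score(y, p))
    for ind in class_inds:
        mask = ~np.isnan(labs[:, ind])
        y, p = labs[:, ind][mask], preds[:, ind][mask]
        results["roc"].append(roc_auc_score(y, p))
        precision, recall, _ = precision_recall_curve(y, p)
        results["prc"].append(auc(recall, precision))
    return results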
Beispiel #12
0
    def train2(self, start_epoch=1):

        logger.info('New training')

        LEARNING_RATE = 2e-5  # 2e-6 does not work (?)

        optimizer = Adam(self.model.parameters(), lr=LEARNING_RATE)

        loss_func = BCELoss()

        for epoch_num in range(start_epoch, start_epoch + self.epochs):
            self.model.train()
            train_loss = 0

            # print(f'Epoch: {epoch_num}/{start_epoch + self.epochs}')

            # for step, batch in enumerate(tqdm_notebook(train_dataloader, desc="Iteration")):
            for batch_idx, batch_data in enumerate(self.progress(self.train_data_loader, desc=f'Training Epoch: {epoch_num}/{self.epochs}')):
                batch_data = tuple(t.to(self.device) for t in batch_data)
                xs = self.train_data_loader.get_x_from_batch(batch_data)
                ys = self.train_data_loader.get_y_from_batch(batch_data)

                # Pass to model
                ys_pred = self.model(*xs)

                batch_loss = loss_func(
                    self.model_output_to_loss_input(ys_pred),
                    self.data_loader_to_loss_input(ys),
                )

                train_loss += batch_loss.item()

                self.model.zero_grad()
                batch_loss.backward()
                optimizer.step()

                # clear_output(wait=True)

            # print("\r" + "{0}/{1} loss: {2} ".format(step_num, len(self.train_data_loader) / self.data_helper.train_batch_size,
            #                                          train_loss / (step_num + 1)))

            # Loss
            train_loss = train_loss / (batch_idx + 1)
            logger.info(f'Loss/Train after {epoch_num} epochs: {train_loss}')

            # test if requested or last epoch
            if epoch_num == self.epochs or (
                    self.test_every_n_epoch > 0 and (epoch_num % self.test_every_n_epoch) == 0):
                report, outputs = self.test2(self.test_data_loader)

                self.reports[epoch_num] = report
                self.outputs[epoch_num] = outputs

                for avg in ['micro avg', 'macro avg']:
                    if avg in report:
                        for metric in ['precision', 'recall', 'f1-score']:
                            if metric in report[avg]:
                                avg_label = avg.replace(' ', '_')

                                logger.info(f'{avg_label}-{metric}/Test: {report[avg][metric]}')

                                if self.tensorboard_writer:
                                    self.tensorboard_writer.add_scalar(f'{avg_label}-{metric}/Test',
                                                                       report[avg][metric], epoch_num)


            logger.info(str(torch.cuda.memory_allocated(self.device) / 1000000) + 'M')

        logger.info('Debug training completed')

        return self.reports
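# Side note (illustrative sketch, not part of the example above): BCELoss
# expects probabilities in [0, 1], so the model's output must already have
# passed through a sigmoid. With raw logits, BCEWithLogitsLoss is the
# numerically safer equivalent:
import torch
from torch.nn import BCELoss, BCEWithLogitsLoss

logits = torch.randn(4, 1)
target = torch.randint(0, 2, (4, 1)).float()
loss_a = BCELoss()(torch.sigmoid(logits), target)
loss_b = BCEWithLogitsLoss()(logits, target)  # equal up to floating point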
Beispiel #13
0
    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                subject_ids=None,
                batch_size=None,
                obj_labels=None,
                sub_train=False,
                obj_train=False):

        outputs_1 = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        sequence_output = outputs_1[0]
        sequence_output = self.dropout(sequence_output)
        if sub_train:
            logits = self.classifier(sequence_output)
            outputs = (logits,
                       )  # add hidden states and attention if they are here
            loss_fct = BCELoss(reduction='none')
            # loss_fct = BCEWithLogitsLoss(reduction='none')
            loss_sig = nn.Sigmoid()
            # Only keep active parts of the loss
            active_logits = logits.view(-1, self.num_labels)
            active_logits = loss_sig(active_logits)
            active_logits = active_logits**2
            if labels is not None:
                active_labels = labels.view(-1, self.num_labels).float()
                loss = loss_fct(active_logits, active_labels)
                # loss = loss.view(-1,sequence_output.size()[1],2)
                loss = loss.view(batch_size, -1, 2)
                loss = torch.mean(loss, 2)
                loss = torch.sum(
                    attention_mask * loss) / torch.sum(attention_mask)
                outputs = (loss, ) + outputs
            else:
                outputs = active_logits
        if obj_train:
            hidden_states = outputs_1[2][-2]
            # hidden_states = self.dropout(hidden_states)
            loss_obj = BCELoss(reduction='none')
            loss_sig = nn.Sigmoid()
            # sub_pos_start = self.sub_pos_emb(subject_ids[:, :1]).to(device)
            # sub_pos_end = self.sub_pos_emb(subject_ids[:, 1:]).to(device)
            # subject_start,subject_end = extrac_subject_1(hidden_states, subject_ids)
            # subject = extrac_subject(hidden_states, subject_ids)
            # subject_start = subject_start.to(device)
            # subject_end = subject_end.to(device)
            # subject = (sub_pos_start + subject_start + sub_pos_end + subject_end).to(device)
            subject = extrac_subject(hidden_states, subject_ids).to(device)
            batch_token_ids_obj = torch.add(hidden_states, subject)
            # batch_token_ids_obj = self.LayerNorm(batch_token_ids_obj)
            # batch_token_ids_obj = self.dropout(batch_token_ids_obj)
            # batch_token_ids_obj = F.dropout(batch_token_ids_obj,p=0.5)
            # batch_token_ids_obj = self.relu(self.linear(batch_token_ids_obj))
            # batch_token_ids_obj = self.dropout(batch_token_ids_obj)
            obj_logits = self.obj_classifier(batch_token_ids_obj)
            # obj_logits = self.dropout(obj_logits)
            # obj_logits = F.dropout(obj_logits,p=0.4)
            obj_logits = loss_sig(obj_logits)
            obj_logits = obj_logits**4
            obj_outputs = (obj_logits, )
            if obj_labels is not None:
                # obj_loss = loss_obj(obj_logits.view(-1, hidden_states.size()[1], self.obj_labels // 2, 2), obj_labels.float())
                obj_loss = loss_obj(
                    obj_logits.view(batch_size, -1, self.obj_labels // 2, 2),
                    obj_labels.float())
                obj_loss = torch.sum(torch.mean(obj_loss, 3), 2)
                obj_loss = torch.sum(
                    obj_loss * attention_mask) / torch.sum(attention_mask)
                s_o_loss = torch.add(obj_loss, loss)
                outputs_obj = (s_o_loss, ) + obj_outputs
            else:
                # outputs_obj = obj_logits.view(-1,hidden_states.size()[1],self.obj_labels // 2 ,2)
                outputs_obj = obj_logits.view(batch_size, -1,
                                              self.obj_labels // 2, 2)
        if obj_train:
            return outputs, outputs_obj  # (loss), scores, (hidden_states), (attentions)
        else:
            return outputs
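# Aside (illustrative sketch of the masking pattern used above, not taken from
# the source): with BCELoss(reduction='none') the per-token losses can be
# weighted by the attention mask so that padding positions do not contribute.
# Shapes and tensor names here are assumptions.
import torch
from torch.nn import BCELoss

probs = torch.sigmoid(torch.randn(2, 5, 2))        # (batch, seq_len, labels)
labels = torch.randint(0, 2, (2, 5, 2)).float()
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]]).float()

per_token = BCELoss(reduction='none')(probs, labels).mean(dim=-1)  # (batch, seq)
loss = (per_token * attention_mask).sum() / attention_mask.sum()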
Beispiel #14
0
    def __init__(self):
        super().__init__()
        self.criterion = BCELoss()
Beispiel #15
0
def train(args):
    r"""
    Run training for the chosen model using the configurations in args
    
    :param args: parsed arguments
    """
    # =====================set seeds========================
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.device != "cpu":
        torch.cuda.manual_seed(args.seed)

    # =====================import data======================
    dataset_train = ToulouseRoadNetworkDataset(
        split=args.train_split, max_prev_node=args.max_prev_node)
    dataset_valid = ToulouseRoadNetworkDataset(
        split="valid", max_prev_node=args.max_prev_node)

    dataloader_train = DataLoader(dataset_train,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  collate_fn=custom_collate_fn)
    dataloader_valid = DataLoader(dataset_valid,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  collate_fn=custom_collate_fn)

    print("Dataset splits -> Train: {} | Valid: {}\n".format(
        len(dataset_train), len(dataset_valid)))

    # =====================init models======================
    encoder, optimizer_enc = load_encoder(args)
    decoder, optimizer_dec = load_decoder(args)
    run_epoch = get_epoch_fn(args)

    # use this to continue training using existing checkpoints
    # encoder.load_state_dict(torch.load(args.checkpoints_path + "/encoder.pth"))
    # decoder.load_state_dict(torch.load(args.checkpoints_path + "/decoder.pth"))

    # =====================init losses=======================
    criterion_mse = MSELoss(reduction='mean')
    criterion_bce = BCELoss(reduction='mean')
    criterions = {"bce": criterion_bce, "mse": criterion_mse}

    # ===================random guessing====================
    loss_valid, loss_valid_adj, loss_valid_coord = run_epoch(args,
                                                             0,
                                                             dataloader_valid,
                                                             decoder,
                                                             encoder,
                                                             optimizer_dec,
                                                             optimizer_enc,
                                                             criterions,
                                                             is_eval=True)
    print_and_log(
        'Epoch {}/{} || Train loss: {:.4f} Adj: {:.4f} Coord: {:.4f} ||'
        ' Valid loss: {:.4f} Adj: {:.4f} Coord: {:.4f}'.format(
            0, args.epochs, 0, 0, 0, loss_valid, loss_valid_adj,
            loss_valid_coord), args.file_logs)

    # ========================train=========================
    start_time = time.time()
    writer = SummaryWriter(args.file_tensorboard)
    min_loss_valid = (10000000, 0)

    for epoch in range(args.epochs):
        if time.time() - start_time > args.max_time:
            break

        loss_train, loss_train_adj, loss_train_coord = run_epoch(
            args,
            epoch + 1,
            dataloader_train,
            decoder,
            encoder,
            optimizer_dec,
            optimizer_enc,
            criterions,
            is_eval=False)
        loss_valid, loss_valid_adj, loss_valid_coord = run_epoch(
            args,
            epoch + 1,
            dataloader_valid,
            decoder,
            encoder,
            optimizer_dec,
            optimizer_enc,
            criterions,
            is_eval=True)

        # ========================log and plot=========================
        if epoch % 1 == 0:
            print_and_log(
                'Epoch {}/{} || Train loss: {:.4f} Adj: {:.4f} Coord: {:.4f} ||'
                ' Valid loss: {:.4f} Adj: {:.4f} Coord: {:.4f}'.format(
                    epoch + 1, args.epochs, loss_train, loss_train_adj,
                    loss_train_coord, loss_valid, loss_valid_adj,
                    loss_valid_coord), args.file_logs)

        # ========================update curves========================
        save_losses(args, loss_train, loss_train_adj, loss_train_coord,
                    loss_valid, loss_valid_adj, loss_valid_coord)
        update_writer(writer, "Loss", loss_train, loss_valid, epoch)
        update_writer(writer, "Loss Adj", loss_train_adj, loss_valid_adj,
                      epoch)
        update_writer(writer, "Loss Coord", loss_train_coord, loss_valid_coord,
                      epoch)

        # =========================save models=========================
        if min_loss_valid[0] > loss_valid:
            min_loss_valid = (loss_valid, epoch + 1)
            torch.save(decoder.state_dict(),
                       args.checkpoints_path + "/decoder.pth")
            if encoder is not None:
                torch.save(encoder.state_dict(),
                           args.checkpoints_path + "/encoder.pth")

    print("\nTraining Completed!")
    print_and_log(
        "Minimum loss on validation set: {} at epoch {}".format(
            min_loss_valid[0], min_loss_valid[1]), args.file_logs)
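# Hypothetical sketch of how the two criteria above are typically combined
# inside an epoch function for graph generation: BCE on predicted adjacency
# entries and MSE on predicted node coordinates. The tensor names are
# assumptions; the actual run_epoch implementation may differ.
def combined_loss(criterions, adj_pred, adj_true, coord_pred, coord_true):
    """Return total, adjacency and coordinate losses."""
    loss_adj = criterions["bce"](adj_pred, adj_true)      # adj_pred in [0, 1]
    loss_coord = criterions["mse"](coord_pred, coord_true)
    return loss_adj + loss_coord, loss_adj, loss_coord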
Beispiel #16
0
import os
import matplotlib.pyplot as plt
import numpy as np

# TRAIN_BATCH_SIZE = 200
# TEST_BATCH_SIZE = 300

TRAIN_BATCH_SIZE = 200
TEST_BATCH_SIZE = 756

N_EPOCH = 20000
CHECK_LOSS_INTERVAL = 1
assert N_EPOCH % CHECK_LOSS_INTERVAL == 0

mse = MSELoss(reduction='none')
bce = BCELoss(reduction='none')

torch.set_printoptions(profile="full", precision=2, linewidth=10000)
torch.manual_seed(20)

MODEL_SAVE_DIR = os.path.join("int_phy", "locomotion", "pre_trained_old")


def set_loss(dataloader, net):
    total_loss = 0
    h_0 = torch.zeros(size=(NUM_HIDDEN_LAYER,
                            dataloader.batch_size, HIDDEN_STATE_SIZE)).to(
                                DEVICE)  # (num_layer, batch_size, hidden_size)
    c_0 = torch.zeros(size=(NUM_HIDDEN_LAYER,
                            dataloader.batch_size, HIDDEN_STATE_SIZE)).to(
                                DEVICE)  # (num_layer, batch_size, hidden_size)
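# Illustrative sketch (assumption, not from the truncated example above): the
# zero-initialised (h_0, c_0) tensors of shape (num_layers, batch, hidden)
# are meant to be passed as the initial state of an nn.LSTM. Constants here
# are placeholders.
import torch
import torch.nn as nn

NUM_LAYERS, HIDDEN, BATCH, SEQ_LEN, IN_SIZE = 2, 16, 4, 10, 8
lstm = nn.LSTM(IN_SIZE, HIDDEN, NUM_LAYERS, batch_first=True)
x = torch.randn(BATCH, SEQ_LEN, IN_SIZE)
h_0 = torch.zeros(NUM_LAYERS, BATCH, HIDDEN)
c_0 = torch.zeros(NUM_LAYERS, BATCH, HIDDEN)
out, (h_n, c_n) = lstm(x, (h_0, c_0))   # out: (BATCH, SEQ_LEN, HIDDEN)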
Beispiel #17
0
def main(_run, _config, _seed, _log):
    # Setting logger
    args = _config
    logger = _log

    logger.info(args)
    logger.info('It started at: %s' % datetime.now())

    torch.manual_seed(_seed)

    device = torch.device('cuda' if args['cuda'] else "cpu")
    if args['cuda']:
        logger.info("Turning CUDA on")
    else:
        logger.info("Turning CUDA off")

    # Setting the parameter to save and loading parameters
    important_parameters = ['dbr_cnn']
    parameters_to_save = dict([(name, args[name])
                               for name in important_parameters])

    if args['load'] is not None:
        map_location = (
            lambda storage, loc: storage.cuda()) if args['cuda'] else 'cpu'
        model_info = torch.load(args['load'], map_location=map_location)
        model_state = model_info['model']

        for param_name, param_value in model_info['params'].items():
            args[param_name] = param_value
    else:
        model_state = None

    # Set basic variables
    preprocessors = PreprocessorList()
    input_handlers = []
    report_database = BugReportDatabase.fromJson(args['bug_database'])
    batchSize = args['batch_size']
    dbr_cnn_opt = args['dbr_cnn']

    # Loading word embedding and lexicon
    emb = np.load(dbr_cnn_opt["word_embedding"])
    padding_sym = "</s>"

    lexicon = Lexicon(unknownSymbol=None)
    with codecs.open(dbr_cnn_opt["lexicon"]) as f:
        for l in f:
            lexicon.put(l.strip())

    lexicon.setUnknown("UUUKNNN")
    padding_id = lexicon.getLexiconIndex(padding_sym)
    embedding = Embedding(lexicon, emb, paddingIdx=padding_id)

    logger.info("Lexicon size: %d" % (lexicon.getLen()))
    logger.info("Word Embedding size: %d" % (embedding.getEmbeddingSize()))

    # Load filters and tokenizer
    filters = loadFilters(dbr_cnn_opt['filters'])

    if dbr_cnn_opt['tokenizer'] == 'default':
        logger.info("Use default tokenizer to tokenize summary information")
        tokenizer = MultiLineTokenizer()
    elif dbr_cnn_opt['tokenizer'] == 'white_space':
        logger.info(
            "Use white space tokenizer to tokenize summary information")
        tokenizer = WhitespaceTokenizer()
    else:
        raise ArgumentError(
            "Tokenizer value %s is invalid. You should choose one of these: default and white_space"
            % dbr_cnn_opt['tokenizer'])

    # Add preprocessors
    preprocessors.append(
        DBR_CNN_CategoricalPreprocessor(dbr_cnn_opt['categorical_lexicon'],
                                        report_database))
    preprocessors.append(
        SummaryDescriptionPreprocessor(lexicon, report_database, filters,
                                       tokenizer, padding_id))

    # Add input_handlers
    input_handlers.append(DBRDCNN_CategoricalInputHandler())
    input_handlers.append(
        TextCNNInputHandler(padding_id, min(dbr_cnn_opt["window"])))

    # Create Model
    model = DBR_CNN(embedding, dbr_cnn_opt["window"], dbr_cnn_opt["nfilters"],
                    dbr_cnn_opt['update_embedding'])

    model.to(device)

    if model_state:
        model.load_state_dict(model_state)

    # Set loss function
    logger.info("Using BCE Loss")
    loss_fn = BCELoss()
    loss_no_reduction = BCELoss(reduction='none')
    cmp_collate = PairBugCollate(input_handlers,
                                 torch.float32,
                                 unsqueeze_target=True)

    # Loading the training and setting how the negative example will be generated.
    if args.get('pairs_training'):
        negative_pair_gen_opt = args.get('neg_pair_generator')
        pairsTrainingFile = args.get('pairs_training')
        random_anchor = negative_pair_gen_opt['random_anchor']

        offlineGeneration = not (negative_pair_gen_opt is None
                                 or negative_pair_gen_opt['type'] == 'none')

        if not offlineGeneration:
            logger.info("Not generate dynamically the negative examples.")
            pair_training_reader = PairBugDatasetReader(
                pairsTrainingFile,
                preprocessors,
                randomInvertPair=args['random_switch'])
        else:
            pair_gen_type = negative_pair_gen_opt['type']
            master_id_by_bug_id = report_database.getMasterIdByBugId()

            if pair_gen_type == 'random':
                logger.info("Random Negative Pair Generator")
                training_dataset = BugDataset(
                    negative_pair_gen_opt['training'])
                bug_ids = training_dataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (training_dataset.info, len(bug_ids)))

                negative_pair_generator = RandomGenerator(
                    preprocessors, cmp_collate, negative_pair_gen_opt['rate'],
                    bug_ids, master_id_by_bug_id)

            elif pair_gen_type == 'non_negative':
                logger.info("Non Negative Pair Generator")
                training_dataset = BugDataset(
                    negative_pair_gen_opt['training'])
                bug_ids = training_dataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (training_dataset.info, len(bug_ids)))

                negative_pair_generator = NonNegativeRandomGenerator(
                    preprocessors,
                    cmp_collate,
                    negative_pair_gen_opt['rate'],
                    bug_ids,
                    master_id_by_bug_id,
                    negative_pair_gen_opt['n_tries'],
                    device,
                    randomAnchor=random_anchor)
            elif pair_gen_type == 'misc_non_zero':
                logger.info("Misc Non Zero Pair Generator")
                training_dataset = BugDataset(
                    negative_pair_gen_opt['training'])
                bug_ids = training_dataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (training_dataset.info, len(bug_ids)))

                negative_pair_generator = MiscNonZeroRandomGen(
                    preprocessors, cmp_collate, negative_pair_gen_opt['rate'],
                    bug_ids, training_dataset.duplicateIds,
                    master_id_by_bug_id, device,
                    negative_pair_gen_opt['n_tries'],
                    negative_pair_gen_opt['random_anchor'])
            elif pair_gen_type == 'random_k':
                logger.info("Random K Negative Pair Generator")
                training_dataset = BugDataset(
                    negative_pair_gen_opt['training'])
                bug_ids = training_dataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (training_dataset.info, len(bug_ids)))

                negative_pair_generator = KRandomGenerator(
                    preprocessors, cmp_collate, negative_pair_gen_opt['rate'],
                    bug_ids, master_id_by_bug_id, negative_pair_gen_opt['k'],
                    device)
            elif pair_gen_type == "pre":
                logger.info("Pre-selected list generator")
                negative_pair_generator = PreSelectedGenerator(
                    negative_pair_gen_opt['pre_list_file'], preprocessors,
                    negative_pair_gen_opt['rate'], master_id_by_bug_id,
                    negative_pair_gen_opt['preselected_length'])
            elif pair_gen_type == "misc_non_zero_pre":
                logger.info("Pre-selected list generator")

                negativePairGenerator1 = PreSelectedGenerator(
                    negative_pair_gen_opt['pre_list_file'], preprocessors,
                    negative_pair_gen_opt['rate'], master_id_by_bug_id,
                    negative_pair_gen_opt['preselected_length'])

                training_dataset = BugDataset(
                    negative_pair_gen_opt['training'])
                bug_ids = training_dataset.bugIds

                negativePairGenerator2 = NonNegativeRandomGenerator(
                    preprocessors, cmp_collate, negative_pair_gen_opt['rate'],
                    bug_ids, master_id_by_bug_id, device,
                    negative_pair_gen_opt['n_tries'])

                negative_pair_generator = MiscOfflineGenerator(
                    (negativePairGenerator1, negativePairGenerator2))
            else:
                raise ArgumentError(
                    "Offline generator is invalid (%s). You should choose one of these: random, hard and pre"
                    % pair_gen_type)

            pair_training_reader = PairBugDatasetReader(
                pairsTrainingFile,
                preprocessors,
                negative_pair_generator,
                randomInvertPair=args['random_switch'])

        training_loader = DataLoader(pair_training_reader,
                                     batch_size=batchSize,
                                     collate_fn=cmp_collate.collate,
                                     shuffle=True)
        logger.info("Training size: %s" % (len(training_loader.dataset)))

    # load validation
    if args.get('pairs_validation'):
        pair_validation_reader = PairBugDatasetReader(
            args.get('pairs_validation'), preprocessors)
        validation_loader = DataLoader(pair_validation_reader,
                                       batch_size=batchSize,
                                       collate_fn=cmp_collate.collate)

        logger.info("Validation size: %s" % (len(validation_loader.dataset)))
    else:
        validation_loader = None
    """
    Training and evaluate the model. 
    """
    optimizer_opt = args.get('optimizer', 'adam')

    if optimizer_opt == 'sgd':
        logger.info('SGD')
        optimizer = optim.SGD(model.parameters(),
                              lr=args['lr'],
                              weight_decay=args['l2'],
                              momentum=args['momentum'])
    elif optimizer_opt == 'adam':
        logger.info('Adam')
        optimizer = optim.Adam(model.parameters(),
                               lr=args['lr'],
                               weight_decay=args['l2'])

    # Recall rate
    ranking_scorer = DBR_CNN_Scorer(preprocessors[0], preprocessors[1],
                                    input_handlers[0], input_handlers[1],
                                    model, device, args['ranking_batch_size'])
    recallEstimationTrainOpt = args.get('recall_estimation_train')

    if recallEstimationTrainOpt:
        preselectListRankingTrain = PreselectListRanking(
            recallEstimationTrainOpt)

    recallEstimationOpt = args.get('recall_estimation')

    if recallEstimationOpt:
        preselect_list_ranking = PreselectListRanking(recallEstimationOpt)

    lr_scheduler_opt = args.get('lr_scheduler', None)

    if lr_scheduler_opt is None or lr_scheduler_opt['type'] == 'constant':
        logger.info("Scheduler: Constant")
        lr_sched = None
    elif lr_scheduler_opt["type"] == 'step':
        logger.info("Scheduler: StepLR (step:%s, decay:%f)" %
                    (lr_scheduler_opt["step_size"], args["decay"]))
        lr_sched = StepLR(optimizer, lr_scheduler_opt["step_size"],
                          lr_scheduler_opt["decay"])
    elif lr_scheduler_opt["type"] == 'exp':
        logger.info("Scheduler: ExponentialLR (decay:%f)" %
                    (lr_scheduler_opt["decay"]))
        lr_sched = ExponentialLR(optimizer, lr_scheduler_opt["decay"])
    elif lr_scheduler_opt["type"] == 'linear':
        logger.info(
            "Scheduler: Divide by (1 + epoch * decay) ---- (decay:%f)" %
            (lr_scheduler_opt["decay"]))

        lrDecay = lr_scheduler_opt["decay"]
        lr_sched = LambdaLR(optimizer, lambda epoch: 1 /
                            (1.0 + epoch * lrDecay))
    else:
        raise ArgumentError(
            "LR Scheduler is invalid (%s). You should choose one of these: constant, step, exp and linear"
            % lr_scheduler_opt["type"])

    # Set training functions
    def trainingIteration(engine, batch):
        model.train()

        optimizer.zero_grad()
        x, y = cmp_collate.to(batch, device)
        output = model(*x)
        loss = loss_fn(output, y)
        loss.backward()
        optimizer.step()
        return loss, output, y

    trainer = Engine(trainingIteration)
    negTarget = 0.0 if isinstance(loss_fn, NLLLoss) else -1.0

    trainingMetrics = {
        'training_loss':
        AverageLoss(loss_fn),
        'training_acc':
        AccuracyWrapper(output_transform=thresholded_output_transform),
        'training_precision':
        PrecisionWrapper(output_transform=thresholded_output_transform),
        'training_recall':
        RecallWrapper(output_transform=thresholded_output_transform),
    }

    # Add metrics to trainer
    for name, metric in trainingMetrics.items():
        metric.attach(trainer, name)

    # Set validation functions
    def validationIteration(engine, batch):
        model.eval()

        with torch.no_grad():
            x, y = cmp_collate.to(batch, device)
            y_pred = model(*x)

            return y_pred, y

    validationMetrics = {
        'validation_loss':
        LossWrapper(loss_fn),
        'validation_acc':
        AccuracyWrapper(output_transform=thresholded_output_transform),
        'validation_precision':
        PrecisionWrapper(output_transform=thresholded_output_transform),
        'validation_recall':
        RecallWrapper(output_transform=thresholded_output_transform),
    }
    evaluator = Engine(validationIteration)

    # Add metrics to evaluator
    for name, metric in validationMetrics.items():
        metric.attach(evaluator, name)

    @trainer.on(Events.EPOCH_STARTED)
    def onStartEpoch(engine):
        epoch = engine.state.epoch
        logger.info("Epoch: %d" % epoch)

        if lr_sched:
            lr_sched.step()

        logger.info("LR: %s" % str(optimizer.param_groups[0]["lr"]))

    @trainer.on(Events.EPOCH_COMPLETED)
    def onEndEpoch(engine):
        epoch = engine.state.epoch

        logMetrics(_run, logger, engine.state.metrics, epoch)

        # Evaluate Training
        if validation_loader:
            evaluator.run(validation_loader)
            logMetrics(_run, logger, evaluator.state.metrics, epoch)

        lastEpoch = epoch == args['epochs']

        if recallEstimationTrainOpt and (epoch % args['rr_train_epoch'] == 0):
            logRankingResult(_run, logger, preselectListRankingTrain,
                             ranking_scorer, report_database, None, epoch,
                             "train")
            ranking_scorer.free()

        if recallEstimationOpt and (epoch % args['rr_val_epoch'] == 0):
            logRankingResult(_run, logger, preselect_list_ranking,
                             ranking_scorer, report_database,
                             args.get("ranking_result_file"), epoch,
                             "validation")
            ranking_scorer.free()

        if not lastEpoch:
            pair_training_reader.sampleNewNegExamples(model, loss_no_reduction)

        if args.get('save'):
            save_by_epoch = args['save_by_epoch']

            if save_by_epoch and epoch in save_by_epoch:
                file_name, file_extension = os.path.splitext(args['save'])
                file_path = file_name + '_epoch_{}'.format(
                    epoch) + file_extension
            else:
                file_path = args['save']

            modelInfo = {
                'model': model.state_dict(),
                'params': parameters_to_save
            }

            logger.info("==> Saving Model: %s" % file_path)
            torch.save(modelInfo, file_path)

    if args.get('pairs_training'):
        trainer.run(training_loader, max_epochs=args['epochs'])
    elif args.get('pairs_validation'):
        # Evaluate Training
        evaluator.run(validation_loader)
        logMetrics(_run, logger, evaluator.state.metrics, 0)

        if recallEstimationOpt:
            logRankingResult(_run, logger, preselect_list_ranking,
                             ranking_scorer, report_database,
                             args.get("ranking_result_file"), 0, "validation")

    # Test Dataset (accuracy, recall, precision, F1)
    pair_test_dataset = args.get('pair_test_dataset')

    if pair_test_dataset is not None and len(pair_test_dataset) > 0:
        pairTestReader = PairBugDatasetReader(pair_test_dataset, preprocessors)
        testLoader = DataLoader(pairTestReader,
                                batch_size=batchSize,
                                collate_fn=cmp_collate.collate)

        if not isinstance(cmp_collate, PairBugCollate):
            raise NotImplementedError(
                'Evaluation of pairs using tanh was not implemented yet')

        logger.info("Test size: %s" % (len(testLoader.dataset)))

        testMetrics = {
            'test_accuracy':
            ignite.metrics.Accuracy(
                output_transform=thresholded_output_transform),
            'test_precision':
            ignite.metrics.Precision(
                output_transform=thresholded_output_transform),
            'test_recall':
            ignite.metrics.Recall(
                output_transform=thresholded_output_transform),
            'test_predictions':
            PredictionCache(),
        }
        test_evaluator = Engine(validationIteration)

        # Add metrics to evaluator
        for name, metric in testMetrics.items():
            metric.attach(test_evaluator, name)

        test_evaluator.run(testLoader)

        for metricName, metricValue in test_evaluator.state.metrics.items():
            metric = testMetrics[metricName]

            if isinstance(metric, ignite.metrics.Accuracy):
                logger.info({
                    'type': 'metric',
                    'label': metricName,
                    'value': metricValue,
                    'epoch': None,
                    'correct': metric._num_correct,
                    'total': metric._num_examples
                })
                _run.log_scalar(metricName, metricValue)
            elif isinstance(metric,
                            (ignite.metrics.Precision, ignite.metrics.Recall)):
                logger.info({
                    'type': 'metric',
                    'label': metricName,
                    'value': metricValue,
                    'epoch': None,
                    'tp': metric._true_positives.item(),
                    'total_positive': metric._positives.item()
                })
                _run.log_scalar(metricName, metricValue)
            elif isinstance(metric, ConfusionMatrix):
                acc = cmAccuracy(metricValue)
                prec = cmPrecision(metricValue, False)
                recall = cmRecall(metricValue, False)
                f1 = 2 * (prec * recall) / (prec + recall + 1e-15)

                logger.info({
                    'type':
                    'metric',
                    'label':
                    metricName,
                    'accuracy':
                    float(acc),
                    'precision':
                    prec.cpu().numpy().tolist(),
                    'recall':
                    recall.cpu().numpy().tolist(),
                    'f1':
                    f1.cpu().numpy().tolist(),
                    'confusion_matrix':
                    metricValue.cpu().numpy().tolist(),
                    'epoch':
                    None
                })

                _run.log_scalar('test_f1', f1[1])
            elif isinstance(metric, PredictionCache):
                logger.info({
                    'type': 'metric',
                    'label': metricName,
                    'predictions': metric.predictions
                })

    # Calculate recall rate
    recall_rate_opt = args.get('recall_rate', {'type': 'none'})
    if recall_rate_opt['type'] != 'none':
        if recall_rate_opt['type'] == 'sun2011':
            logger.info("Calculating recall rate: {}".format(
                recall_rate_opt['type']))
            recall_rate_dataset = BugDataset(recall_rate_opt['dataset'])

            ranking_class = SunRanking(report_database, recall_rate_dataset,
                                       recall_rate_opt['window'])
            # We always group all bug reports by master in the results in the sun 2011 methodology
            group_by_master = True
        elif recall_rate_opt['type'] == 'deshmukh':
            logger.info("Calculating recall rate: {}".format(
                recall_rate_opt['type']))
            recall_rate_dataset = BugDataset(recall_rate_opt['dataset'])
            ranking_class = DeshmukhRanking(report_database,
                                            recall_rate_dataset)
            group_by_master = recall_rate_opt['group_by_master']
        else:
            raise ArgumentError(
                "recall_rate.type is invalid (%s). You should choose one of these: none, sun2011 and deshmukh"
                % recall_rate_opt['type'])

        logRankingResult(
            _run,
            logger,
            ranking_class,
            ranking_scorer,
            report_database,
            recall_rate_opt["result_file"],
            0,
            None,
            group_by_master,
        )
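# The output_transform helper used by the Accuracy/Precision/Recall wrappers
# above is not shown in the snippet. A plausible minimal version (an
# assumption, not the source's definition) rounds the sigmoid probabilities
# to hard 0/1 predictions before the metric sees them:
import torch

def thresholded_output_transform(output):
    y_pred, y = output          # engine output: (predictions, targets)
    return torch.round(y_pred), y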
Beispiel #18
0
def main():
    all_files = ""
    for file in Path(__file__).parent.resolve().glob('*.py'):
        with open(str(file), 'r', encoding='utf-8') as f:
            all_files += f.read()
    print(hashlib.md5(all_files.encode()).hexdigest())

    NUM_EPOCHS = 10
    parser = ArgumentParser()
    parser.add_argument("--from-checkpoint")
    parser.add_argument("--train-epochs", type=int)
    args = parser.parse_args()

    datasets_config = define_dataset_config()
    tasks_config = define_tasks_config(datasets_config)

    task_actions = []
    for task in iter(Task):
        if task not in [Task.SNLI, Task.SciTail,
                        Task.WNLI]:  # Train only GLUE task
            train_loader = tasks_config[task]["train_loader"]
            task_actions.extend([task] * len(train_loader))
    epoch_steps = len(task_actions)

    model = MT_BERT()
    model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=5e-5)
    initial_epoch = 1
    training_start = datetime.datetime.now().isoformat()

    if args.from_checkpoint:
        print("Loading from checkpoint")
        checkpoint = torch.load(args.from_checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        initial_epoch = checkpoint['epoch'] + 1
        training_start = checkpoint["training_start"]
        warmup_scheduler = None
        lr_scheduler = None
    else:
        print("Starting training from scratch")
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda step: 1.0)
        warmup_scheduler = warmup.LinearWarmup(
            optimizer, warmup_period=(epoch_steps * NUM_EPOCHS) // 10)

    if args.train_epochs:
        NUM_EPOCHS = initial_epoch + args.train_epochs - 1
    print(
        f"------------------ training-start:  {training_start} --------------------------)"
    )

    losses = {
        'BCELoss': BCELoss(),
        'CrossEntropyLoss': CrossEntropyLoss(),
        'MSELoss': MSELoss()
    }
    for name, loss in losses.items():
        losses[name].to(device)

    for epoch in range(initial_epoch, NUM_EPOCHS + 1):
        with stream_redirect_tqdm() as orig_stdout:
            epoch_bar = tqdm(sample(task_actions, len(task_actions)),
                             file=orig_stdout)
            model.train()

            for task_action in epoch_bar:
                train_loader = tasks_config[task_action]["train_loader"]
                epoch_bar.set_description(
                    f"current task: {task_action.name} in epoch:{epoch}")

                data = next(iter(train_loader))

                optimizer.zero_grad(set_to_none=True)

                data_columns = [
                    col for col in tasks_config[task_action]["columns"]
                    if col != "label"
                ]
                input_data = list(zip(*(data[col] for col in data_columns)))

                label = data["label"]
                if label.dtype == torch.float64:
                    label = label.to(torch.float32)
                if task_action == Task.QNLI:
                    label = label.to(torch.float32)

                task_criterion = losses[MT_BERT.loss_for_task(task_action)]

                if len(data_columns) == 1:
                    input_data = list(map(operator.itemgetter(0), input_data))

                label = label.to(device)
                train_minibatch(input_data=input_data,
                                task=task_action,
                                label=label,
                                model=model,
                                task_criterion=task_criterion,
                                optimizer=optimizer)

                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
                optimizer.step()

                if warmup_scheduler:
                    lr_scheduler.step()
                    warmup_scheduler.dampen()

            results_folder = Path(f"results_{training_start}")
            results_folder.mkdir(exist_ok=True)

            models_path = results_folder / "saved_models"
            models_path.mkdir(exist_ok=True)
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'training_start': training_start
                }, str(models_path / f'epoch_{epoch}.tar'))

            model.eval()
            val_results = {}
            with torch.no_grad():
                task_bar = tqdm([
                    task for task in Task
                    if task not in [Task.SNLI, Task.SciTail, Task.WNLI]
                ],
                                file=orig_stdout)
                for task in task_bar:
                    task_bar.set_description(task.name)
                    val_loader = tasks_config[task]["val_loader"]

                    task_predicted_labels = torch.empty(0, device=device)
                    task_labels = torch.empty(0, device=device)
                    for val_data in val_loader:
                        data_columns = [
                            col for col in tasks_config[task]["columns"]
                            if col != "label"
                        ]
                        input_data = list(
                            zip(*(val_data[col] for col in data_columns)))
                        label = val_data["label"].to(device)

                        if len(data_columns) == 1:
                            input_data = list(
                                map(operator.itemgetter(0), input_data))

                        model_output = model(input_data, task)

                        if task == Task.QNLI:
                            predicted_label = torch.round(model_output)
                        elif task.num_classes() > 1:
                            predicted_label = torch.argmax(model_output, -1)
                        else:
                            predicted_label = model_output

                        task_predicted_labels = torch.hstack(
                            (task_predicted_labels, predicted_label.view(-1)))
                        task_labels = torch.hstack((task_labels, label))

                    metrics = datasets_config[task].metrics
                    for metric in metrics:
                        metric_result = metric(task_labels.cpu(),
                                               task_predicted_labels.cpu())
                        if isinstance(metric_result,
                                      (tuple, scipy.stats.stats.SpearmanrResult)):
                            metric_result = metric_result[0]
                        val_results[task.name, metric.__name__] = metric_result
                        print(
                            f"val_results[{task.name}, {metric.__name__}] = {val_results[task.name, metric.__name__]}"
                        )
            data_frame = pd.DataFrame(data=val_results, index=[epoch])
            data_frame.to_csv(str(results_folder / "train_GLUE_results.csv"),
                              mode='a',
                              index_label='Epoch')
Beispiel #19
0
def standard_cross_entropy(s_hat, log_r_hat, t0_hat, t1_hat, y_true, r_true,
                           t0_true, t1_true):
    s_hat = 1.0 / (1.0 + torch.exp(log_r_hat))

    return BCELoss()(s_hat, y_true)
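# The decision function above is s_hat = 1 / (1 + r) = sigmoid(-log r), so the
# same loss can be written with BCEWithLogitsLoss on -log_r_hat, avoiding the
# explicit exp/sigmoid (illustrative equivalence, not part of the source):
import torch
from torch.nn import BCELoss, BCEWithLogitsLoss

log_r_hat = torch.randn(8)
y_true = torch.randint(0, 2, (8,)).float()
s_hat = 1.0 / (1.0 + torch.exp(log_r_hat))        # == torch.sigmoid(-log_r_hat)
loss_a = BCELoss()(s_hat, y_true)
loss_b = BCEWithLogitsLoss()(-log_r_hat, y_true)  # equal up to floating point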
Beispiel #20
0
def objective(trial):

    TREE_DEPTH = trial.suggest_int('TREE_DEPTH', 2, 6)
    REG = trial.suggest_loguniform('REG', 1e-3, 1e3)
    print(f'TREE_DEPTH={TREE_DEPTH}, REG={REG}')

    if not LINEAR:
        MLP_LAYERS = trial.suggest_int('MLP_LAYERS', 2, 7)
        DROPOUT = trial.suggest_uniform('DROPOUT', 0.0, 0.5)
        print(f'MLP_LAYERS={MLP_LAYERS}, DROPOUT={DROPOUT}')

    pruning = REG > 0

    if LINEAR:
        save_dir = root_dir / "depth={}/reg={}/seed={}".format(
            TREE_DEPTH, REG, SEED)
        model = LTBinaryClassifier(TREE_DEPTH,
                                   data.X_train.shape[1],
                                   reg=REG,
                                   linear=LINEAR)

    else:
        save_dir = root_dir / "depth={}/reg={}/mlp-layers={}/dropout={}/seed={}".format(
            TREE_DEPTH, REG, MLP_LAYERS, DROPOUT, SEED)
        model = LTBinaryClassifier(TREE_DEPTH,
                                   data.X_train.shape[1],
                                   reg=REG,
                                   linear=LINEAR,
                                   layers=MLP_LAYERS,
                                   dropout=DROPOUT)

    print(model.count_parameters(), "model's parameters")

    save_dir.mkdir(parents=True, exist_ok=True)

    # init optimizer
    optimizer = QHAdam(model.parameters(),
                       lr=LR,
                       nus=(0.7, 1.0),
                       betas=(0.995, 0.998))

    # init learning rate scheduler
    lr_scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=2)

    # init loss
    criterion = BCELoss(reduction="sum")

    # evaluation criterion => error rate
    eval_criterion = lambda x, y: (x.long() != y.long()).sum()

    # init train-eval monitoring
    monitor = MonitorTree(pruning, save_dir)

    state = {
        'batch-size': BATCH_SIZE,
        'loss-function': 'BCE',
        'learning-rate': LR,
        'seed': SEED,
        'dataset': DATA_NAME,
    }

    best_val_loss = float("inf")
    best_e = -1
    no_improv = 0
    for e in range(EPOCHS):
        train_stochastic(trainloader,
                         model,
                         optimizer,
                         criterion,
                         epoch=e,
                         monitor=monitor)

        val_loss = evaluate(valloader,
                            model, {'ER': eval_criterion},
                            epoch=e,
                            monitor=monitor)

        no_improv += 1
        if val_loss['ER'] < best_val_loss:
            best_val_loss = val_loss['ER']
            best_e = e
            no_improv = 0
            # save_model(model, optimizer, state, save_dir)

        # reduce learning rate if needed
        lr_scheduler.step(val_loss['ER'])
        monitor.write(model, e, train={"lr": optimizer.param_groups[0]['lr']})

        trial.report(val_loss['ER'], e)
        # Handle pruning based on the intermediate value.
        if trial.should_prune() or np.isnan(val_loss['ER']):
            monitor.close()
            raise optuna.TrialPruned()

        if no_improv == 10:
            break

    print("Best validation ER:", best_val_loss)
    monitor.close()

    return best_val_loss
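# The objective above is meant to be driven by an Optuna study. A minimal
# launch sketch (values such as n_trials are choices, not taken from the
# source):
import optuna

study = optuna.create_study(direction="minimize",
                            pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=50)
print(study.best_params, study.best_value)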
Beispiel #21
0
def loss_function(recon_x, x, mu, logvar):
    BCE = reconstruction_function(recon_x, x)

    # https://arxiv.org/abs/1312.6114 (Appendix B)
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)

    return BCE + KLD


model = VAE()
if GPU_NUMS > 0:
    model.cuda()

reconstruction_function = BCELoss()
reconstruction_function.size_average = False
optimizer = Adam(model.parameters(), lr=1e-4)

dataset = ImageFolder(root='/input/face/64_crop',
                      transform=Compose([ToTensor()]))
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

bar = ProgressBar(EPOCH, len(train_loader), "Loss:%.3f")

model.train()
train_loss = 0
for epoch in range(EPOCH):
    for ii, (image, label) in enumerate(train_loader):
        mini_batch = image.shape[0]
        data = Variable(image.cuda() if GPU_NUMS > 0 else image)
Beispiel #22
0
def main_train():
    discriminator = PatchGANDiscriminator()
    generator = Generator()

    learning_rate = .005
    updateable_params = filter(lambda p: p.requires_grad,
                               discriminator.parameters())
    discriminator_optimizer = optim.Adam(updateable_params, lr=learning_rate)
    updateable_params = filter(lambda p: p.requires_grad,
                               generator.parameters())
    generator_optimizer = optim.Adam(updateable_params, lr=learning_rate)
    criterion = BCELoss().cuda()

    if use_gpu:
        discriminator = discriminator.cuda()
        generator = generator.cuda()
        criterion = criterion.cuda()

    train_loader, gan_loader = load_loaders()

    for i in range(1, 10):
        print(f"GAN EPOCH: {i}")
        # Train discriminator on color images
        train(discriminator,
              None,
              discriminator_optimizer,
              criterion,
              train_loader,
              use_gpu,
              epochs=1,
              total_iter=0,
              epoch_start=0)
        # Train discriminator on fakes
        train(discriminator,
              generator,
              generator_optimizer,
              criterion,
              gan_loader,
              use_gpu,
              epochs=10,
              total_iter=0,
              epoch_start=0)
        real_score, fake_score = score_it(discriminator, generator)
        print("REAL SCORE:", real_score)
        print("FAKE SCORE:", fake_score)

    print("These should all be 1")
    debug_dat_model(discriminator, None, train_loader)
    print("These should all be 0")
    debug_dat_model(discriminator, generator, gan_loader)

    now_time = time.strftime("%H-%M-%S")
    os.makedirs("./pickles", exist_ok=True)

    disc_file = f"./pickles/discriminator-{now_time}.p"
    gen_file = f"./pickles/generator-{now_time}.p"
    print("Saving to: ")
    print(disc_file)
    print(gen_file)
    torch.save(discriminator, disc_file)
    torch.save(generator, gen_file)
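# Illustrative sketch of the standard discriminator objective behind the
# BCELoss criterion above: real images are pushed towards label 1, generated
# (fake) images towards label 0. Names are assumptions, not the repository's
# train() internals.
import torch
from torch.nn import BCELoss

def discriminator_loss(criterion, d_real, d_fake):
    real_targets = torch.ones_like(d_real)
    fake_targets = torch.zeros_like(d_fake)
    return criterion(d_real, real_targets) + criterion(d_fake, fake_targets)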
Beispiel #23
0
        best_test_acc = ckpt['acc']
        agent.load_state_dict(ckpt['net'])
    else:
        # train starting with pre-trained network
        print('train from scratch...')
        pretrained_dict = torch.load(os.path.join('policy_ckpt',
                                                  'resnet10.t7'))
        model_dict = agent.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        agent.load_state_dict(model_dict)

criterion = BCELoss()
optimizer = optim.Adam(agent.parameters(), lr=args.lr)

# only parallelize for imagenet
if dataset_type == utils.Data.imagenet and device == 'cuda':
    agent = torch.nn.DataParallel(agent)
    rnet = torch.nn.DataParallel(rnet)

# using gpu
agent.to(device)
rnet.to(device)

if args.test_policy:
    epoch = 0
    configure(os.path.join(rootpath, 'log/test_policy'), flush_secs=5)
    test_policy(epoch)
Beispiel #24
0
'''
Read in the data and preprocess it
'''
transform = Compose([
    Scale(IMG_SIZE),
    ToTensor(),
    Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
train_loader = torch.utils.data.DataLoader(
    # MNIST('data', train=True, download=True, transform=transform),
    MNISTDataSet('../ganData/mnist.npz', train=True, transform=transform),
    batch_size=BATCH_SIZE,
    shuffle=True)
'''
Start training
'''
BCE_loss = BCELoss()
G_optimizer = Adam(Net_G.parameters(), lr=LR, betas=(0.5, 0.999))
D_optimizer = Adam(Net_D.parameters(), lr=LR, betas=(0.5, 0.999))
bar = ProgressBar(EPOCHS, len(train_loader), "D Loss:%.3f; G Loss:%.3f")

# label preprocess
onehot = torch.zeros(10, 10)
onehot = onehot.scatter_(
    1,
    torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).view(10, 1),
    1).view(10, 10, 1, 1)
fill = torch.zeros([10, 10, IMG_SIZE, IMG_SIZE])

for i in range(10):
    fill[i, i, :, :] = 1
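# Illustrative usage sketch (assumption, not from the source): the tensors
# built above are indexed by a batch of digit labels to condition the
# generator and the discriminator in a conditional GAN.
import torch

y = torch.randint(0, 10, (4,))
g_condition = onehot[y]   # (4, 10, 1, 1) condition for the generator
d_condition = fill[y]     # (4, 10, IMG_SIZE, IMG_SIZE) condition map for D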
Beispiel #25
0
    def __init__(self, i_logging: int, discriminator: Discriminator,
                 generator: Generator):
        self.loss = BCELoss()
        super().__init__(i_logging, discriminator, generator)
Beispiel #26
0
def main(_run, _config, _seed, _log):
    """

    :param _run:
    :param _config:
    :param _seed:
    :param _log:
    :return:
    """
    """
    Setting and loading parameters
    """
    # Setting logger
    args = _config
    logger = _log

    logger.info(args)
    logger.info('It started at: %s' % datetime.now())

    torch.manual_seed(_seed)
    bugReportDatabase = BugReportDatabase.fromJson(args['bug_database'])
    paddingSym = "</s>"
    batchSize = args['batch_size']

    device = torch.device('cuda' if args['cuda'] else "cpu")

    if args['cuda']:
        logger.info("Turning CUDA on")
    else:
        logger.info("Turning CUDA off")

    # It is the folder where the preprocessed information will be stored.
    cacheFolder = args['cache_folder']

    # Setting the parameter to save and loading parameters
    importantParameters = ['compare_aggregation', 'categorical']
    parametersToSave = dict([(parName, args[parName])
                             for parName in importantParameters])

    if args['load'] is not None:
        mapLocation = (
            lambda storage, loc: storage.cuda()) if args['cuda'] else 'cpu'
        modelInfo = torch.load(args['load'], map_location=mapLocation)
        modelState = modelInfo['model']

        for paramName, paramValue in modelInfo['params'].items():
            args[paramName] = paramValue
    else:
        modelState = None

    preprocessors = PreprocessorList()
    inputHandlers = []

    categoricalOpt = args.get('categorical')

    if categoricalOpt is not None and len(categoricalOpt) != 0:
        categoricalEncoder, _, _ = processCategoricalParam(
            categoricalOpt, bugReportDatabase, inputHandlers, preprocessors,
            None, logger)
    else:
        categoricalEncoder = None

    filterInputHandlers = []

    compareAggOpt = args['compare_aggregation']
    databasePath = args['bug_database']

    # Loading word embedding
    if compareAggOpt["lexicon"]:
        emb = np.load(compareAggOpt["word_embedding"])

        lexicon = Lexicon(unknownSymbol=None)
        with codecs.open(compareAggOpt["lexicon"]) as f:
            for l in f:
                lexicon.put(l.strip())

        lexicon.setUnknown("UUUKNNN")
        paddingId = lexicon.getLexiconIndex(paddingSym)
        embedding = Embedding(lexicon, emb, paddingIdx=paddingId)

        logger.info("Lexicon size: %d" % (lexicon.getLen()))
        logger.info("Word Embedding size: %d" % (embedding.getEmbeddingSize()))
    elif compareAggOpt["word_embedding"]:
        # TODO: allow using embeddings together with other representations
        lexicon, embedding = Embedding.fromFile(
            compareAggOpt['word_embedding'],
            'UUUKNNN',
            hasHeader=False,
            paddingSym=paddingSym)
        logger.info("Lexicon size: %d" % (lexicon.getLen()))
        logger.info("Word Embedding size: %d" % (embedding.getEmbeddingSize()))
        paddingId = lexicon.getLexiconIndex(paddingSym)
    else:
        embedding = None

    if compareAggOpt["norm_word_embedding"]:
        embedding.zscoreNormalization()

    # Tokenizer
    if compareAggOpt['tokenizer'] == 'default':
        logger.info("Use default tokenizer to tokenize summary information")
        tokenizer = MultiLineTokenizer()
    elif compareAggOpt['tokenizer'] == 'white_space':
        logger.info(
            "Use white space tokenizer to tokenize summary information")
        tokenizer = WhitespaceTokenizer()
    else:
        raise ArgumentError(
            "Tokenizer value %s is invalid. You should choose one of these: default and white_space"
            % compareAggOpt['tokenizer'])

    # Preparing input handlers, preprocessors and cache
    minSeqSize = max(compareAggOpt['aggregate']["window"]
                     ) if compareAggOpt['aggregate']["model"] == "cnn" else -1
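    # With a CNN aggregator, every sequence must be padded to at least the largest
    # convolution window; -1 means no minimum length is enforced.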
    bow = compareAggOpt.get('bow', False)
    freq = compareAggOpt.get('frequency', False) and bow

    logger.info("BoW={} and TF={}".format(bow, freq))

    if compareAggOpt['extractor'] is not None:
        # Use summary and description (concatenated) to address this problem
        logger.info("Using Summary and Description information.")
        # Loading Filters
        extractorFilters = loadFilters(compareAggOpt['extractor']['filters'])

        arguments = (databasePath, compareAggOpt['word_embedding'],
                     str(compareAggOpt['lexicon']), ' '.join(
                         sorted([
                             fil.__class__.__name__ for fil in extractorFilters
                         ])), compareAggOpt['tokenizer'], str(bow), str(freq),
                     SABDEncoderPreprocessor.__name__)

        inputHandlers.append(SABDInputHandler(paddingId, minSeqSize))
        extractorCache = PreprocessingCache(cacheFolder, arguments)

        if bow:
            extractorPreprocessor = SABDBoWPreprocessor(
                lexicon, bugReportDatabase, extractorFilters, tokenizer,
                paddingId, freq, extractorCache)
        else:
            extractorPreprocessor = SABDEncoderPreprocessor(
                lexicon, bugReportDatabase, extractorFilters, tokenizer,
                paddingId, extractorCache)
        preprocessors.append(extractorPreprocessor)

    # Create model
    model = SABD(embedding, categoricalEncoder, compareAggOpt['extractor'],
                 compareAggOpt['matching'], compareAggOpt['aggregate'],
                 compareAggOpt['classifier'], freq)

    if args['loss'] == 'bce':
        logger.info("Using BCE Loss: margin={}".format(args['margin']))
        lossFn = BCELoss()
        lossNoReduction = BCELoss(reduction='none')
        cmp_collate = PairBugCollate(inputHandlers,
                                     torch.float32,
                                     unsqueeze_target=True)
    elif args['loss'] == 'triplet':
        logger.info("Using Triplet Loss: margin={}".format(args['margin']))
        lossFn = TripletLoss(args['margin'])
        lossNoReduction = TripletLoss(args['margin'], reduction='none')
        cmp_collate = TripletBugCollate(inputHandlers)
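    # PairBugCollate batches (bug pair, binary target) examples for the BCE setup,
    # while TripletBugCollate batches (anchor, positive, negative) triplets for the
    # triplet-loss setup.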

    model.to(device)

    if modelState:
        model.load_state_dict(modelState)
    """
    Loading the training and validation. Also, it sets how the negative example will be generated.
    """
    # load training
    if args.get('pairs_training'):
        negativePairGenOpt = args.get('neg_pair_generator', )
        trainingFile = args.get('pairs_training')

        offlineGeneration = not (negativePairGenOpt is None
                                 or negativePairGenOpt['type'] == 'none')
        masterIdByBugId = bugReportDatabase.getMasterIdByBugId()
        randomAnchor = negativePairGenOpt['random_anchor']

        if not offlineGeneration:
            logger.info("Not generate dynamically the negative examples.")
            negativePairGenerator = None
        else:
            pairGenType = negativePairGenOpt['type']

            if pairGenType == 'random':
                logger.info("Random Negative Pair Generator")
                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (trainingDataset.info, len(bugIds)))

                negativePairGenerator = RandomGenerator(
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    randomAnchor=randomAnchor)

            elif pairGenType == 'non_negative':
                logger.info("Non Negative Pair Generator")
                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (trainingDataset.info, len(bugIds)))

                negativePairGenerator = NonNegativeRandomGenerator(
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    negativePairGenOpt['n_tries'],
                    device,
                    randomAnchor=randomAnchor)
            elif pairGenType == 'misc_non_zero':
                logger.info("Misc Non Zero Pair Generator")
                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (trainingDataset.info, len(bugIds)))

                negativePairGenerator = MiscNonZeroRandomGen(
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    trainingDataset.duplicateIds,
                    masterIdByBugId,
                    negativePairGenOpt['n_tries'],
                    device,
                    randomAnchor=randomAnchor)
            elif pairGenType == 'product_component':
                logger.info("Product Component Pair Generator")
                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (trainingDataset.info, len(bugIds)))

                negativePairGenerator = ProductComponentRandomGen(
                    bugReportDatabase,
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    negativePairGenOpt['n_tries'],
                    device,
                    randomAnchor=randomAnchor)

            elif pairGenType == 'random_k':
                logger.info("Random K Negative Pair Generator")
                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (trainingDataset.info, len(bugIds)))

                negativePairGenerator = KRandomGenerator(
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    negativePairGenOpt['k'],
                    device,
                    randomAnchor=randomAnchor)
            elif pairGenType == "pre":
                logger.info("Pre-selected list generator")
                negativePairGenerator = PreSelectedGenerator(
                    negativePairGenOpt['pre_list_file'],
                    preprocessors,
                    negativePairGenOpt['rate'],
                    masterIdByBugId,
                    negativePairGenOpt['preselected_length'],
                    randomAnchor=randomAnchor)

            elif pairGenType == "positive_pre":
                logger.info("Positive Pre-selected list generator")
                negativePairGenerator = PositivePreSelectedGenerator(
                    negativePairGenOpt['pre_list_file'],
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    masterIdByBugId,
                    negativePairGenOpt['preselected_length'],
                    randomAnchor=randomAnchor)
            elif pairGenType == "misc_non_zero_pre":
                logger.info("Misc: non-zero and Pre-selected list generator")
                negativePairGenerator1 = PreSelectedGenerator(
                    negativePairGenOpt['pre_list_file'],
                    preprocessors,
                    negativePairGenOpt['rate'],
                    masterIdByBugId,
                    negativePairGenOpt['preselected_length'],
                    randomAnchor=randomAnchor)

                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                negativePairGenerator2 = NonNegativeRandomGenerator(
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    negativePairGenOpt['n_tries'],
                    device,
                    randomAnchor=randomAnchor)

                negativePairGenerator = MiscOfflineGenerator(
                    (negativePairGenerator1, negativePairGenerator2))
            elif pairGenType == "misc_non_zero_positive_pre":
                logger.info(
                    "Misc: non-zero and Positive Pre-selected list generator")
                negativePairGenerator1 = PositivePreSelectedGenerator(
                    negativePairGenOpt['pre_list_file'],
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    masterIdByBugId,
                    negativePairGenOpt['preselected_length'],
                    randomAnchor=randomAnchor)

                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                negativePairGenerator2 = NonNegativeRandomGenerator(
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    negativePairGenOpt['n_tries'],
                    device,
                    randomAnchor=randomAnchor)

                negativePairGenerator = MiscOfflineGenerator(
                    (negativePairGenerator1, negativePairGenerator2))

            else:
                raise ArgumentError(
                    "Offline generator is invalid (%s). You should choose one of these: random, non_negative, misc_non_zero, product_component, random_k, pre, positive_pre, misc_non_zero_pre and misc_non_zero_positive_pre"
                    % pairGenType)

        if isinstance(lossFn, BCELoss):
            training_reader = PairBugDatasetReader(
                trainingFile,
                preprocessors,
                negativePairGenerator,
                randomInvertPair=args['random_switch'])
        elif isinstance(lossFn, TripletLoss):
            training_reader = TripletBugDatasetReader(
                trainingFile,
                preprocessors,
                negativePairGenerator,
                randomInvertPair=args['random_switch'])

        trainingLoader = DataLoader(training_reader,
                                    batch_size=batchSize,
                                    collate_fn=cmp_collate.collate,
                                    shuffle=True)
        logger.info("Training size: %s" % (len(trainingLoader.dataset)))

    # load validation
    if args.get('pairs_validation'):
        if isinstance(lossFn, BCELoss):
            validation_reader = PairBugDatasetReader(
                args.get('pairs_validation'), preprocessors)
        elif isinstance(lossFn, TripletLoss):
            validation_reader = TripletBugDatasetReader(
                args.get('pairs_validation'), preprocessors)

        validationLoader = DataLoader(validation_reader,
                                      batch_size=batchSize,
                                      collate_fn=cmp_collate.collate)

        logger.info("Validation size: %s" % (len(validationLoader.dataset)))
    else:
        validationLoader = None
    """
    Training and evaluate the model. 
    """
    optimizer_opt = args.get('optimizer', 'adam')

    if optimizer_opt == 'sgd':
        logger.info('SGD')
        optimizer = optim.SGD(model.parameters(),
                              lr=args['lr'],
                              weight_decay=args['l2'])
    elif optimizer_opt == 'adam':
        logger.info('Adam')
        optimizer = optim.Adam(model.parameters(),
                               lr=args['lr'],
                               weight_decay=args['l2'])

    # Recall rate
    rankingScorer = GeneralScorer(
        model, preprocessors, device,
        PairBugCollate(inputHandlers, ignore_target=True),
        args['ranking_batch_size'], args['ranking_n_workers'])
    recallEstimationTrainOpt = args.get('recall_estimation_train')

    if recallEstimationTrainOpt:
        preselectListRankingTrain = PreselectListRanking(
            recallEstimationTrainOpt, args['sample_size_rr_tr'])

    recallEstimationOpt = args.get('recall_estimation')

    if recallEstimationOpt:
        preselectListRanking = PreselectListRanking(recallEstimationOpt,
                                                    args['sample_size_rr_val'])

    # LR scheduler
    lrSchedulerOpt = args.get('lr_scheduler', None)

    if lrSchedulerOpt is None:
        logger.info("Scheduler: Constant")
        lrSched = None
    elif lrSchedulerOpt["type"] == 'step':
        logger.info("Scheduler: StepLR (step:%s, decay:%f)" %
                    (lrSchedulerOpt["step_size"], args["decay"]))
        lrSched = StepLR(optimizer, lrSchedulerOpt["step_size"],
                         lrSchedulerOpt["decay"])
    elif lrSchedulerOpt["type"] == 'exp':
        logger.info("Scheduler: ExponentialLR (decay:%f)" %
                    (lrSchedulerOpt["decay"]))
        lrSched = ExponentialLR(optimizer, lrSchedulerOpt["decay"])
    elif lrSchedulerOpt["type"] == 'linear':
        logger.info(
            "Scheduler: Divide by (1 + epoch * decay) ---- (decay:%f)" %
            (lrSchedulerOpt["decay"]))

        lrDecay = lrSchedulerOpt["decay"]
        lrSched = LambdaLR(optimizer, lambda epoch: 1 /
                           (1.0 + epoch * lrDecay))
    else:
        raise ArgumentError(
            "LR Scheduler is invalid (%s). You should choose one of these: step, exp and linear"
            % lrSchedulerOpt["type"])

    # Set training functions
    def trainingIteration(engine, batch):
        engine.kk = 0
        model.train()

        optimizer.zero_grad()
        x, y = cmp_collate.to(batch, device)
        output = model(*x)
        loss = lossFn(output, y)
        loss.backward()
        optimizer.step()
        return loss, output, y

    def scoreDistanceTrans(output):
        if len(output) == 3:
            _, y_pred, y = output
        else:
            y_pred, y = output

        if lossFn == F.nll_loss:
            return torch.exp(y_pred[:, 1]), y
        elif isinstance(lossFn, (BCELoss)):
            return y_pred, y

    trainer = Engine(trainingIteration)
    trainingMetrics = {'training_loss': AverageLoss(lossFn)}

    if isinstance(lossFn, BCELoss):
        trainingMetrics['training_dist_target'] = MeanScoreDistance(
            output_transform=scoreDistanceTrans)
        trainingMetrics['training_acc'] = AccuracyWrapper(
            output_transform=thresholded_output_transform)
        trainingMetrics['training_precision'] = PrecisionWrapper(
            output_transform=thresholded_output_transform)
        trainingMetrics['training_recall'] = RecallWrapper(
            output_transform=thresholded_output_transform)
    # Add metrics to trainer
    for name, metric in trainingMetrics.items():
        metric.attach(trainer, name)
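    # Note: thresholded_output_transform is defined elsewhere in the project. As an
    # assumption (not the author's exact code), such a transform usually rounds the
    # sigmoid outputs at 0.5 before Accuracy/Precision/Recall are computed, e.g.:
    #
    #     def thresholded_output_transform(output):
    #         y_pred, y = output
    #         return torch.round(y_pred), y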

    # Set validation functions
    def validationIteration(engine, batch):
        if not hasattr(engine, 'kk'):
            engine.kk = 0

        model.eval()

        with torch.no_grad():
            x, y = cmp_collate.to(batch, device)
            y_pred = model(*x)

            return y_pred, y

    validationMetrics = {
        'validation_loss':
        LossWrapper(lossFn,
                    output_transform=lambda x: (x[0], x[0][0])
                    if x[1] is None else x)
    }

    if isinstance(lossFn, BCELoss):
        validationMetrics['validation_dist_target'] = MeanScoreDistance(
            output_transform=scoreDistanceTrans)
        validationMetrics['validation_acc'] = AccuracyWrapper(
            output_transform=thresholded_output_transform)
        validationMetrics['validation_precision'] = PrecisionWrapper(
            output_transform=thresholded_output_transform)
        validationMetrics['validation_recall'] = RecallWrapper(
            output_transform=thresholded_output_transform)

    evaluator = Engine(validationIteration)

    # Add metrics to evaluator
    for name, metric in validationMetrics.items():
        metric.attach(evaluator, name)

    # recommendation
    recommendation_fn = generateRecommendationList

    @trainer.on(Events.EPOCH_STARTED)
    def onStartEpoch(engine):
        epoch = engine.state.epoch
        logger.info("Epoch: %d" % epoch)

        if lrSched:
            lrSched.step()

        logger.info("LR: %s" % str(optimizer.param_groups[0]["lr"]))

    @trainer.on(Events.EPOCH_COMPLETED)
    def onEndEpoch(engine):
        epoch = engine.state.epoch

        logMetrics(_run, logger, engine.state.metrics, epoch)

        # Evaluate Training
        if validationLoader:
            evaluator.run(validationLoader)
            logMetrics(_run, logger, evaluator.state.metrics, epoch)

        lastEpoch = args['epochs'] - epoch == 0

        if recallEstimationTrainOpt and (epoch % args['rr_train_epoch'] == 0):
            logRankingResult(_run,
                             logger,
                             preselectListRankingTrain,
                             rankingScorer,
                             bugReportDatabase,
                             None,
                             epoch,
                             "train",
                             recommendationListfn=recommendation_fn)
            rankingScorer.free()

        if recallEstimationOpt and (epoch % args['rr_val_epoch'] == 0):
            logRankingResult(_run,
                             logger,
                             preselectListRanking,
                             rankingScorer,
                             bugReportDatabase,
                             args.get("ranking_result_file"),
                             epoch,
                             "validation",
                             recommendationListfn=recommendation_fn)
            rankingScorer.free()

        if not lastEpoch:
            training_reader.sampleNewNegExamples(model, lossNoReduction)

        if args.get('save'):
            save_by_epoch = args['save_by_epoch']

            if save_by_epoch and epoch in save_by_epoch:
                file_name, file_extension = os.path.splitext(args['save'])
                file_path = file_name + '_epoch_{}'.format(
                    epoch) + file_extension
            else:
                file_path = args['save']

            modelInfo = {
                'model': model.state_dict(),
                'params': parametersToSave
            }

            logger.info("==> Saving Model: %s" % file_path)
            torch.save(modelInfo, file_path)

    if args.get('pairs_training'):
        trainer.run(trainingLoader, max_epochs=args['epochs'])
    elif args.get('pairs_validation'):
        # Evaluate Training
        evaluator.run(validationLoader)
        logMetrics(_run, logger, evaluator.state.metrics, 0)

        if recallEstimationOpt:
            logRankingResult(_run,
                             logger,
                             preselectListRanking,
                             rankingScorer,
                             bugReportDatabase,
                             args.get("ranking_result_file"),
                             0,
                             "validation",
                             recommendationListfn=recommendation_fn)

    # Test Dataset (accuracy, recall, precision, F1)
    pair_test_dataset = args.get('pair_test_dataset')

    if pair_test_dataset is not None and len(pair_test_dataset) > 0:
        pairTestReader = PairBugDatasetReader(pair_test_dataset, preprocessors)
        testLoader = DataLoader(pairTestReader,
                                batch_size=batchSize,
                                collate_fn=cmp_collate.collate)

        if not isinstance(cmp_collate, PairBugCollate):
            raise NotImplementedError(
                'Evaluation of pairs using tanh was not implemented yet')

        logger.info("Test size: %s" % (len(testLoader.dataset)))

        testMetrics = {
            'test_accuracy':
            ignite.metrics.Accuracy(
                output_transform=thresholded_output_transform),
            'test_precision':
            ignite.metrics.Precision(
                output_transform=thresholded_output_transform),
            'test_recall':
            ignite.metrics.Recall(
                output_transform=thresholded_output_transform),
            'test_predictions':
            PredictionCache(),
        }
        test_evaluator = Engine(validationIteration)

        # Add metrics to evaluator
        for name, metric in testMetrics.items():
            metric.attach(test_evaluator, name)

        test_evaluator.run(testLoader)

        for metricName, metricValue in test_evaluator.state.metrics.items():
            metric = testMetrics[metricName]

            if isinstance(metric, ignite.metrics.Accuracy):
                logger.info({
                    'type': 'metric',
                    'label': metricName,
                    'value': metricValue,
                    'epoch': None,
                    'correct': metric._num_correct,
                    'total': metric._num_examples
                })
                _run.log_scalar(metricName, metricValue)
            elif isinstance(metric,
                            (ignite.metrics.Precision, ignite.metrics.Recall)):
                logger.info({
                    'type': 'metric',
                    'label': metricName,
                    'value': metricValue,
                    'epoch': None,
                    'tp': metric._true_positives.item(),
                    'total_positive': metric._positives.item()
                })
                _run.log_scalar(metricName, metricValue)
            elif isinstance(metric, ConfusionMatrix):
                acc = cmAccuracy(metricValue)
                prec = cmPrecision(metricValue, False)
                recall = cmRecall(metricValue, False)
                f1 = 2 * (prec * recall) / (prec + recall + 1e-15)

                logger.info({
                    'type':
                    'metric',
                    'label':
                    metricName,
                    'accuracy':
                    float(acc),
                    'precision':
                    prec.cpu().numpy().tolist(),
                    'recall':
                    recall.cpu().numpy().tolist(),
                    'f1':
                    f1.cpu().numpy().tolist(),
                    'confusion_matrix':
                    metricValue.cpu().numpy().tolist(),
                    'epoch':
                    None
                })

                _run.log_scalar('test_f1', f1[1])
            elif isinstance(metric, PredictionCache):
                logger.info({
                    'type': 'metric',
                    'label': metricName,
                    'predictions': metric.predictions
                })

    # Calculate recall rate
    recallRateOpt = args.get('recall_rate', {'type': 'none'})
    if recallRateOpt['type'] != 'none':
        if recallRateOpt['type'] == 'sun2011':
            logger.info("Calculating recall rate: {}".format(
                recallRateOpt['type']))
            recallRateDataset = BugDataset(recallRateOpt['dataset'])

            rankingClass = SunRanking(bugReportDatabase, recallRateDataset,
                                      recallRateOpt['window'])
            # In the Sun 2011 methodology, the results always group bug reports by their master report
            group_by_master = True
        elif recallRateOpt['type'] == 'deshmukh':
            logger.info("Calculating recall rate: {}".format(
                recallRateOpt['type']))
            recallRateDataset = BugDataset(recallRateOpt['dataset'])
            rankingClass = DeshmukhRanking(bugReportDatabase,
                                           recallRateDataset)
            group_by_master = recallRateOpt['group_by_master']
        else:
            raise ArgumentError(
                "recall_rate.type is invalid (%s). You should choose one of these: sun2011 and deshmukh"
                % recallRateOpt['type'])

        logRankingResult(_run,
                         logger,
                         rankingClass,
                         rankingScorer,
                         bugReportDatabase,
                         recallRateOpt["result_file"],
                         0,
                         None,
                         group_by_master,
                         recommendationListfn=recommendation_fn)
Beispiel #27
0
 def __init__(self):
     super().__init__()
     self.sig = Sigmoid()
     self.loss = BCELoss(reduction='sum')
Beispiel #28
0
    def forward(
        self,
        content_embedding,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)

        content = self.content_dense(content_embedding)

        cat1 = pooled_output + content
        cat2 = pooled_output * content
        cat3 = pooled_output - content
        cat4 = pooled_output / content

        out = torch.cat([pooled_output, content, cat1, cat2, cat3, cat4],
                        dim=-1)

        logits = self.classifier(out)
        logits = self.act(logits)

        loss = None
        if labels is not None:
            loss_fct = BCELoss()
            loss = loss_fct(logits.view(-1), labels.view(-1))

        if not return_dict:
            output = (logits, ) + outputs[2:]
            return ((loss, ) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
Beispiel #29
0
                           shuffle=False)
    testloader = DataLoader(TorchDataset(data.X_test, data.y_test),
                            batch_size=BATCH_SIZE * 2,
                            num_workers=16,
                            shuffle=False)

    model = LTBinaryClassifier(TREE_DEPTH, data.X_train.shape[1], reg=REG)

    # init optimizer
    optimizer = QHAdam(model.parameters(),
                       lr=LR,
                       nus=(0.7, 1.0),
                       betas=(0.995, 0.998))

    # init loss
    loss = BCELoss(reduction="sum")
    criterion = lambda x, y: loss(x.float(), y.float())
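    # The wrapper casts predictions and targets to float so that BCELoss receives
    # matching floating-point dtypes regardless of how the dataset stores its labels.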

    # evaluation criterion => error rate
    eval_criterion = lambda x, y: (x != y).sum()

    # init train-eval monitoring
    monitor = MonitorTree(pruning, save_dir)

    state = {
        'batch-size': BATCH_SIZE,
        'loss-function': 'BCE',
        'learning-rate': LR,
        'seed': SEED,
        'dataset': DATA_NAME,
        'reg': REG,
Beispiel #30
0
def main(args):
    """
    :param n_trials: int specifying number of random train/test splits to use
    :param test_set_size: float in range [0, 1] specifying fraction of dataset to use as test set
    """

    df3 = pd.read_csv('data/rest_activity.smi')

    roc_list = []
    prc_list = []

    for i in range(args.n_trials):
        writer = SummaryWriter('runs/' + args.savename + '/run_' + str(i))

        if args.test:
            df3_train, df3_test = train_test_split(
                df3,
                stratify=df3['activity'],
                test_size=args.test_set_size,
                shuffle=True,
                random_state=i + 5)

            X3_high_train, X3_low_train, y3_train, _, _ = return_pairs(
                df3_train)
            X3_high_test, X3_low_test, y3_test, n_feats, e_feats = return_pairs(
                df3_test, test=False)

            X_high_train = X3_high_train
            X_low_train = X3_low_train
            # X_high_test = np.concatenate([X1_high_test,X2_high_test, X3_high_test])
            # X_low_test = np.concatenate([X1_low_test,X2_low_test, X3_low_test])

            y_train = y3_train
            # y_test = np.concatenate([y1_test, y2_test, y3_test])

            y_train = torch.Tensor(y_train)
            # y1_test/y2_test only exist when the other activity datasets are loaded;
            # only the rest_activity split is evaluated here.
            y3_test = torch.Tensor(y3_test)

            train_data = list(zip(X_high_train, X_low_train, y_train))
            test_rest = list(zip(X3_high_test, X3_low_test, y3_test))

            train_loader = DataLoader(train_data,
                                      batch_size=32,
                                      shuffle=True,
                                      collate_fn=collate,
                                      drop_last=False)
            test_rest = DataLoader(test_rest,
                                   batch_size=32,
                                   shuffle=True,
                                   collate_fn=collate,
                                   drop_last=False)
        else:
            X3_high, X3_low, y3, n_feats, e_feats = return_pairs(df3)

            X_high = X3_high
            X_low = X3_low

            y = y3

            y = torch.Tensor(y)

            train_data = list(zip(X_high, X_low, y))

            train_loader = DataLoader(train_data,
                                      batch_size=32,
                                      shuffle=True,
                                      collate_fn=collate,
                                      drop_last=False)

        n_tasks = 1
        mpnn_net = MPNNPairPredictor(node_in_feats=n_feats,
                                     edge_in_feats=e_feats,
                                     node_out_feats=128,
                                     n_tasks=n_tasks)
        mpnn_net = mpnn_net.to(device)

        class_loss_fn = BCELoss()

        optimizer = torch.optim.Adam(mpnn_net.parameters(), lr=args.lr)

        for epoch in tqdm(range(1, args.n_epochs + 1)):
            epoch_loss = 0
            preds = []
            labs = []
            mpnn_net.train()
            n = 0
            for i, (bg_high, bg_low, labels) in tqdm(enumerate(train_loader)):
                labels = labels.to(device)
                atom_feats_high = bg_high.ndata.pop('h').to(device)
                bond_feats_high = bg_high.edata.pop('e').to(device)
                atom_feats_low = bg_low.ndata.pop('h').to(device)
                bond_feats_low = bg_low.edata.pop('e').to(device)
                y_pred = mpnn_net(bg_high, atom_feats_high, bond_feats_high,
                                  bg_low, atom_feats_low, bond_feats_low)
                # y_pred = F.softmax(y_pred, dim=0)
                #y_pred = torch.exp(y_pred)
                y_pred = F.sigmoid(y_pred)
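                # F.sigmoid is deprecated in recent PyTorch releases; torch.sigmoid(y_pred)
                # is the drop-in replacement.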

                loss = torch.tensor(0)
                loss = loss.to(device)

                if args.debug:
                    print('label: {}'.format(labels))
                    print('y_pred: {}'.format(y_pred))

                loss = loss + class_loss_fn(y_pred, labels)
                if args.debug:
                    print('loss: {}'.format(loss))
                optimizer.zero_grad()
                loss.backward()

                optimizer.step()
                epoch_loss += loss.detach().item()

                labels = labels.cpu().numpy()
                y_pred = y_pred.detach().cpu().numpy()

                # store labels and preds
                preds.append(y_pred)
                labs.append(labels)

            labs = np.concatenate(labs, axis=0)
            preds = np.concatenate(preds, axis=0)

            roc = roc_auc_score(labs, preds)
            precision, recall, thresholds = precision_recall_curve(labs, preds)
            prc = auc(recall, precision)
            if args.debug:
                print('roc: {}'.format(roc))
                print('prc: {}'.format(prc))
            writer.add_scalar('LOSS/train', epoch_loss, epoch)
            writer.add_scalar('train/pair_rocauc', roc, epoch)
            writer.add_scalar('train/pair_prcauc', prc, epoch)

            # if epoch % 20 == 0:
            #     print(f"\nepoch: {epoch}, "
            #           f"LOSS: {epoch_loss:.3f}"
            #           f"\n pair ROC-AUC: {roc:.3f}, "
            #           f"pair PRC-AUC: {prc:.3f}")
            #
            #     try:
            #         torch.save(mpnn_net.state_dict(), '/rds-d2/user/wjm41/hpc-work/models/' + args.savename +
            #                    '/model_epoch_' + str(epoch) + '.pt')
            #     except FileNotFoundError:
            #         cmd = 'mkdir /rds-d2/user/wjm41/hpc-work/models/' + args.savename
            #         os.system(cmd)
            #         torch.save(mpnn_net.state_dict(), '/rds-d2/user/wjm41/hpc-work/models/' + args.savename +
            #                    '/model_epoch_' + str(epoch) + '.pt')
            if args.test:
                mpnn_net.eval()
                preds = []
                labs = []
                for i, (bg_high, bg_low, labels) in enumerate(test_rest):
                    labels = labels.to(device)
                    atom_feats_high = bg_high.ndata.pop('h').to(device)
                    bond_feats_high = bg_high.edata.pop('e').to(device)
                    atom_feats_low = bg_low.ndata.pop('h').to(device)
                    bond_feats_low = bg_low.edata.pop('e').to(device)
                    y_pred = mpnn_net(bg_high, atom_feats_high,
                                      bond_feats_high, bg_low, atom_feats_low,
                                      bond_feats_low)
                    y_pred = F.sigmoid(y_pred)

                    labels = labels.cpu().numpy()
                    y_pred = y_pred.detach().cpu().numpy()

                    preds.append(y_pred)
                    labs.append(labels)

                labs = np.concatenate(labs, axis=0)
                preds = np.concatenate(preds, axis=0)

                roc = roc_auc_score(labs, preds)
                precision, recall, thresholds = precision_recall_curve(
                    labs, preds)
                prc = auc(recall, precision)

                writer.add_scalar('test/pair_rest_rocauc', roc, epoch)
                writer.add_scalar('test/pair_rest_prcauc', prc, epoch)

                if epoch == (args.n_epochs):
                    print(
                        f"\n======================== TEST ========================"
                        f"\n pair ROC-AUC: {roc:.3f}, "
                        f"pair PRC-AUC: {prc:.3f}")

                    roc_list.append(roc)
                    prc_list.append(prc)

        torch.save(
            mpnn_net.state_dict(), '/rds-d2/user/wjm41/hpc-work/models/' +
            args.savename + '/model_epoch_final.pt')
    if args.test:
        roc_list = np.array(roc_list)
        prc_list = np.array(prc_list)

        print("\n TEST")
        print("ROC-AUC: {:.3f} +- {:.3f}".format(
            np.mean(roc_list),
            np.std(roc_list) / np.sqrt(len(roc_list))))
        print("PRC-AUC: {:.3f} +- {:.3f}".format(
            np.mean(prc_list),
            np.std(prc_list) / np.sqrt(len(prc_list))))