Example #1
def prepare_optimisers(args, logger, policy_parameters,
                       environment_parameters):
    if args.optimizer == "adam":
        optimizer_class = torch.optim.Adam
    elif args.optimizer == "adadelta":
        optimizer_class = torch.optim.Adadelta
    else:
        optimizer_class = torch.optim.SGD
    optimizer = {
        "policy":
        optimizer_class(params=policy_parameters,
                        lr=args.pol_lr,
                        weight_decay=args.l2_weight),
        "env":
        optimizer_class(params=environment_parameters,
                        lr=args.env_lr,
                        weight_decay=args.l2_weight)
    }
    lr_scheduler = {
        "policy":
        get_lr_scheduler(logger,
                         optimizer["policy"],
                         patience=args.lr_scheduler_patience),
        "env":
        get_lr_scheduler(logger,
                         optimizer["env"],
                         patience=args.lr_scheduler_patience)
    }
    es = EarlyStopping(mode="max",
                       patience=args.es_patience,
                       threshold=args.es_threshold)
    return optimizer, lr_scheduler, es
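
The bodies of `get_lr_scheduler` and `EarlyStopping` are not shown in these examples. A minimal sketch of `get_lr_scheduler` for this call signature, assuming it wraps `torch.optim.lr_scheduler.ReduceLROnPlateau` and that `logger` is a standard `logging.Logger` (the `factor` value and the log message are illustrative assumptions, not taken from the example):

import torch

def get_lr_scheduler(logger, optimizer, patience=10, threshold=1e-4):
    # Assumed behaviour: reduce the learning rate when the tracked metric
    # stops improving; mode="max" matches EarlyStopping(mode="max") above.
    logger.info("Using ReduceLROnPlateau (patience=%d, threshold=%g)",
                patience, threshold)
    return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      mode="max",
                                                      factor=0.5,
                                                      patience=patience,
                                                      threshold=threshold)
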
Example #2
    def __init__(self, input_nc=3, output_nc=3, gpu_id=None):
        self.device = torch.device(
            f"cuda:{gpu_id}" if gpu_id is not None else 'cpu')
        print(f"Using device {self.device}")

        # Hyperparameters
        self.lambda_idt = 0.5
        self.lambda_A = 10.0
        self.lambda_B = 10.0

        # Define generator networks
        self.netG_A = networks.define_netG(input_nc,
                                           output_nc,
                                           ngf=64,
                                           n_blocks=9,
                                           device=self.device)
        self.netG_B = networks.define_netG(output_nc,
                                           input_nc,
                                           ngf=64,
                                           n_blocks=9,
                                           device=self.device)

        # Define discriminator networks
        self.netD_A = networks.define_netD(output_nc,
                                           ndf=64,
                                           n_layers=3,
                                           device=self.device)
        self.netD_B = networks.define_netD(input_nc,
                                           ndf=64,
                                           n_layers=3,
                                           device=self.device)

        # Define image pools
        self.fake_A_pool = utils.ImagePool(pool_size=50)
        self.fake_B_pool = utils.ImagePool(pool_size=50)

        # Define loss functions
        self.criterionGAN = networks.GANLoss().to(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()

        # Define optimizers
        netG_params = itertools.chain(self.netG_A.parameters(),
                                      self.netG_B.parameters())
        netD_params = itertools.chain(self.netD_A.parameters(),
                                      self.netD_B.parameters())
        self.optimizer_G = torch.optim.Adam(netG_params,
                                            lr=0.0002,
                                            betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(netD_params,
                                            lr=0.0002,
                                            betas=(0.5, 0.999))

        # Learning rate schedulers
        self.scheduler_G = utils.get_lr_scheduler(self.optimizer_G)
        self.scheduler_D = utils.get_lr_scheduler(self.optimizer_D)
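
Here `utils.get_lr_scheduler` takes only the optimizer. A plausible sketch, assuming the usual CycleGAN-style linear-decay policy (the epoch counts `n_epochs` and `n_epochs_decay` are illustrative defaults, not taken from the example):

import torch

def get_lr_scheduler(optimizer, n_epochs=100, n_epochs_decay=100):
    # Assumed policy: hold the initial LR for n_epochs, then decay it
    # linearly to zero over the following n_epochs_decay epochs.
    def lambda_rule(epoch):
        return 1.0 - max(0, epoch - n_epochs) / float(n_epochs_decay + 1)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
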
Example #3
    def __init__(self, hyperparameters):
        super(UNIT_Gender_Trainer, self).__init__()
        lr = hyperparameters['lr']
        self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'])
        self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'])
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr)
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr)
        self.dis_scheduler = get_lr_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_lr_scheduler(self.gen_opt, hyperparameters)
        self.apply(weights_init(hyperparameters['init']))
Example #4
    def __init__(self, config, storage, replay_buffer, state=None):
        set_all_seeds(config.seed)

        self.run_tag = config.run_tag
        self.group_tag = config.group_tag
        self.worker_id = 'learner'
        self.replay_buffer = replay_buffer
        self.storage = storage
        self.config = deepcopy(config)

        if "learner" in self.config.use_gpu_for:
            if torch.cuda.is_available():
                if self.config.learner_gpu_device_id is not None:
                    device_id = self.config.learner_gpu_device_id
                    self.device = torch.device("cuda:{}".format(device_id))
                else:
                    self.device = torch.device("cuda")
            else:
                raise RuntimeError(
                    "GPU was requested but torch.cuda.is_available() is False."
                )
        else:
            self.device = torch.device("cpu")

        self.network = get_network(config, self.device)
        self.network.to(self.device)
        self.network.train()

        self.optimizer = get_optimizer(config, self.network.parameters())
        self.lr_scheduler = get_lr_scheduler(config, self.optimizer)
        self.scalar_loss_fn, self.policy_loss_fn = get_loss_functions(config)

        self.training_step = 0
        self.losses_to_log = {'reward': 0., 'value': 0., 'policy': 0.}

        self.throughput = {
            'total_frames': 0,
            'total_games': 0,
            'training_step': 0,
            'time': {
                'ups': 0,
                'fps': 0
            }
        }

        if self.config.norm_obs:
            self.obs_min = np.array(self.config.obs_range[::2],
                                    dtype=np.float32)
            self.obs_max = np.array(self.config.obs_range[1::2],
                                    dtype=np.float32)
            self.obs_range = self.obs_max - self.obs_min

        if state is not None:
            self.load_state(state)

        Logger.__init__(self)
Example #5
    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device("cuda" if opt.ngpu else "cpu")
        
        self.model, self.classifier = models.get_model(opt.net_type, 
                                                       opt.loss_type, 
                                                       opt.pretrained,
                                                       int(opt.nclasses))
        self.model = self.model.to(self.device)
        self.classifier = self.classifier.to(self.device)

        if opt.ngpu > 1:
            self.model = nn.DataParallel(self.model)
            
        self.loss = models.init_loss(opt.loss_type)
        self.loss = self.loss.to(self.device)

        self.optimizer = utils.get_optimizer(self.model, self.opt)
        self.lr_scheduler = utils.get_lr_scheduler(self.opt, self.optimizer)

        self.train_loader = datasets.generate_loader(opt, 'train')
        self.test_loader = datasets.generate_loader(opt, 'val')
        
        self.epoch = 0
        self.best_epoch = False
        self.training = False
        self.state = {}
        

        self.train_loss = utils.AverageMeter()
        self.test_loss  = utils.AverageMeter()
        self.batch_time = utils.AverageMeter()
        if self.opt.loss_type in ['cce', 'bce', 'mse', 'arc_margin']:
            self.test_metrics = utils.AverageMeter()
        else:
            self.test_metrics = utils.ROCMeter()

        self.best_test_loss = utils.AverageMeter()                    
        self.best_test_loss.update(np.array([np.inf]))

        self.visdom_log_file = os.path.join(self.opt.out_path, 'log_files', 'visdom.log')
        self.vis = Visdom(port=opt.visdom_port,
                          log_to_filename=self.visdom_log_file,
                          env=opt.exp_name + '_' + str(opt.fold))

        self.vis_loss_opts = {'xlabel': 'epoch',
                              'ylabel': 'loss',
                              'title': 'losses',
                              'legend': ['train_loss', 'val_loss']}

        self.vis_epochloss_opts = {'xlabel': 'epoch',
                                   'ylabel': 'loss',
                                   'title': 'epoch_losses',
                                   'legend': ['train_loss', 'val_loss']}
Example #6
    def resume(self, checkpoint_dir, hyperparameters):
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_lr_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_lr_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations
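
Examples #3 and #6 call `get_lr_scheduler(optimizer, hyperparameters[, iterations])`. A sketch under the assumption that it implements a step-decay policy driven by the hyperparameter dict, with `iterations` used to fast-forward the schedule when resuming (the keys `lr_policy`, `step_size` and `gamma` are assumptions):

from torch.optim import lr_scheduler

def get_lr_scheduler(optimizer, hyperparameters, iterations=-1):
    policy = hyperparameters.get('lr_policy', 'constant')
    if policy == 'constant':
        return None
    if policy == 'step':
        # With iterations != -1, each param group must already contain
        # 'initial_lr', which is the case when the optimizer state was
        # saved after a scheduler had been attached once.
        return lr_scheduler.StepLR(optimizer,
                                   step_size=hyperparameters['step_size'],
                                   gamma=hyperparameters['gamma'],
                                   last_epoch=iterations)
    raise NotImplementedError("lr policy [%s] is not implemented" % policy)
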
Example #7
def prepare_optimiser(args, logger, parameters):
    if args.optimizer == "adam":
        optimizer_class = torch.optim.Adam
    elif args.optimizer == "amsgrad":
        optimizer_class = partial(torch.optim.Adam, amsgrad=True)
    elif args.optimizer == "adadelta":
        optimizer_class = torch.optim.Adadelta
    else:
        optimizer_class = torch.optim.SGD
    optimizer = optimizer_class(params=parameters,
                                lr=args.lr,
                                weight_decay=args.l2_weight)
    lr_scheduler = get_lr_scheduler(logger,
                                    optimizer,
                                    patience=args.lr_scheduler_patience,
                                    threshold=args.lr_scheduler_threshold)
    es = EarlyStopping(mode="max",
                       patience=args.es_patience,
                       threshold=args.es_threshold)
    return optimizer, lr_scheduler, es
Example #8
def main(args):

    # Device Configuration #
    device = torch.device(
        f'cuda:{args.gpu_num}' if torch.cuda.is_available() else 'cpu')

    # Fix Seed for Reproducibility #
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Samples, Plots, Weights and CSV Path #
    paths = [
        args.samples_path, args.weights_path, args.csv_path,
        args.inference_path
    ]
    for path in paths:
        make_dirs(path)

    # Prepare Data #
    data = pd.read_csv(args.data_path)[args.column]

    # Prepare Data #
    scaler_1 = StandardScaler()
    scaler_2 = StandardScaler()
    preprocessed_data = pre_processing(data, scaler_1, scaler_2, args.constant,
                                       args.delta)

    train_X, train_Y, test_X, test_Y = prepare_data(data, preprocessed_data,
                                                    args)

    train_X = moving_windows(train_X, args.ts_dim)
    train_Y = moving_windows(train_Y, args.ts_dim)

    test_X = moving_windows(test_X, args.ts_dim)
    test_Y = moving_windows(test_Y, args.ts_dim)

    # Prepare Networks #
    if args.model == 'conv':
        D = ConvDiscriminator(args.ts_dim).to(device)
        G = ConvGenerator(args.latent_dim, args.ts_dim).to(device)

    elif args.model == 'lstm':
        D = LSTMDiscriminator(args.ts_dim).to(device)
        G = LSTMGenerator(args.latent_dim, args.ts_dim).to(device)

    else:
        raise NotImplementedError

    #########
    # Train #
    #########

    if args.mode == 'train':

        # Loss Function #
        if args.criterion == 'l2':
            criterion = nn.MSELoss()

        elif args.criterion == 'wgangp':
            pass

        else:
            raise NotImplementedError

        # Optimizers #
        if args.optim == 'sgd':
            D_optim = torch.optim.SGD(D.parameters(), lr=args.lr, momentum=0.9)
            G_optim = torch.optim.SGD(G.parameters(), lr=args.lr, momentum=0.9)

        elif args.optim == 'adam':
            D_optim = torch.optim.Adam(D.parameters(),
                                       lr=args.lr,
                                       betas=(0., 0.9))
            G_optim = torch.optim.Adam(G.parameters(),
                                       lr=args.lr,
                                       betas=(0., 0.9))

        else:
            raise NotImplementedError

        D_optim_scheduler = get_lr_scheduler(D_optim, args)
        G_optim_scheduler = get_lr_scheduler(G_optim, args)

        # Lists #
        D_losses, G_losses = list(), list()

        # Train #
        print(
            "Training Time Series GAN started with total epoch of {}.".format(
                args.num_epochs))

        for epoch in range(args.num_epochs):

            # Initialize Optimizers #
            G_optim.zero_grad()
            D_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            if args.criterion == 'l2':
                n_critics = 1
            elif args.criterion == 'wgangp':
                n_critics = 5

            for j in range(n_critics):
                series, start_dates = get_samples(train_X, train_Y,
                                                  args.batch_size)

                # Data Preparation #
                series = series.to(device)
                noise = torch.randn(args.batch_size, 1,
                                    args.latent_dim).to(device)

                # Adversarial Loss using Real Image #
                prob_real = D(series.float())

                if args.criterion == 'l2':
                    real_labels = torch.ones(prob_real.size()).to(device)
                    D_real_loss = criterion(prob_real, real_labels)

                elif args.criterion == 'wgangp':
                    D_real_loss = -torch.mean(prob_real)

                # Adversarial Loss using Fake Image #
                fake_series = G(noise)
                prob_fake = D(fake_series.detach())

                if args.criterion == 'l2':
                    fake_labels = torch.zeros(prob_fake.size()).to(device)
                    D_fake_loss = criterion(prob_fake, fake_labels)

                elif args.criterion == 'wgangp':
                    D_fake_loss = torch.mean(prob_fake)
                    D_gp_loss = args.lambda_gp * get_gradient_penalty(
                        D, series.float(), fake_series.float(), device)

                # Calculate Total Discriminator Loss #
                D_loss = D_fake_loss + D_real_loss

                if args.criterion == 'wgangp':
                    # lambda_gp is already applied to D_gp_loss above #
                    D_loss += D_gp_loss

                # Back Propagation and Update #
                D_loss.backward()
                D_optim.step()

            ###################
            # Train Generator #
            ###################

            # Adversarial Loss #
            fake_series = G(noise)
            prob_fake = D(fake_series)

            # Calculate Total Generator Loss #
            if args.criterion == 'l2':
                real_labels = torch.ones(prob_fake.size()).to(device)
                G_loss = criterion(prob_fake, real_labels)

            elif args.criterion == 'wgangp':
                G_loss = -torch.mean(prob_fake)

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            # Adjust Learning Rate #
            D_optim_scheduler.step()
            G_optim_scheduler.step()

            # Print Statistics, Save Model Weights and Series #
            if (epoch + 1) % args.log_every == 0:

                # Print Statistics and Save Model #
                print("Epochs [{}/{}] | D Loss {:.4f} | G Loss {:.4f}".format(
                    epoch + 1, args.num_epochs, np.average(D_losses),
                    np.average(G_losses)))
                torch.save(
                    G.state_dict(),
                    os.path.join(
                        args.weights_path,
                        'TS_using{}_and_{}_Epoch_{}.pkl'.format(
                            G.__class__.__name__, args.criterion.upper(),
                            epoch + 1)))

                # Generate Samples and Save Plots and CSVs #
                series, fake_series = generate_fake_samples(
                    test_X, test_Y, G, scaler_1, scaler_2, args, device)
                plot_series(series, fake_series, G, epoch, args,
                            args.samples_path)
                make_csv(series, fake_series, G, epoch, args, args.csv_path)

    ########
    # Test #
    ########

    elif args.mode == 'test':

        # Load Model Weights #
        G.load_state_dict(
            torch.load(
                os.path.join(
                    args.weights_path, 'TS_using{}_and_{}_Epoch_{}.pkl'.format(
                        G.__class__.__name__, args.criterion.upper(),
                        args.num_epochs))))

        # Lists #
        real, fake = list(), list()

        # Inference #
        for idx in range(0, test_X.shape[0], args.ts_dim):

            # Do not plot if the remaining data is less than time dimension #
            end_ix = idx + args.ts_dim

            if end_ix > len(test_X) - 1:
                break

            # Prepare Data #
            test_data = test_X[idx, :]
            test_data = np.expand_dims(test_data, axis=0)
            test_data = np.expand_dims(test_data, axis=1)
            test_data = torch.from_numpy(test_data).to(device)

            start = test_Y[idx, 0]

            noise = torch.randn(args.val_batch_size, 1,
                                args.latent_dim).to(device)

            # Generate Fake Data #
            with torch.no_grad():
                fake_series = G(noise)

            # Convert to Numpy format for Saving #
            test_data = np.squeeze(test_data.cpu().data.numpy())
            fake_series = np.squeeze(fake_series.cpu().data.numpy())

            test_data = post_processing(test_data, start, scaler_1, scaler_2,
                                        args.delta)
            fake_series = post_processing(fake_series, start, scaler_1,
                                          scaler_2, args.delta)

            real += test_data.tolist()
            fake += fake_series.tolist()

        # Plot, Save to CSV file and Derive Metrics #
        plot_series(real, fake, G, args.num_epochs - 1, args,
                    args.inference_path)
        make_csv(real, fake, G, args.num_epochs - 1, args, args.inference_path)
        derive_metrics(real, fake, args)

    else:
        raise NotImplementedError
Example #9
def main(config):

    # Fix Seed #
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)

    # Weights and Plots Path #
    paths = [config.weights_path, config.plots_path]

    for path in paths:
        make_dirs(path)

    # Prepare Data #
    data = load_data(config.which_data)[[config.feature]]
    data = data.copy()

    # Plot Time-Series Data #
    if config.plot_full:
        plot_full(config.plots_path, data, config.feature)

    scaler = MinMaxScaler()
    data[config.feature] = scaler.fit_transform(data)

    train_loader, val_loader, test_loader = \
        data_loader(data, config.seq_length, config.train_split, config.test_split, config.batch_size)

    # Lists #
    train_losses, val_losses = list(), list()
    val_maes, val_mses, val_rmses, val_mapes, val_mpes, val_r2s = list(), list(), list(), list(), list(), list()
    test_maes, test_mses, test_rmses, test_mapes, test_mpes, test_r2s = list(), list(), list(), list(), list(), list()

    # Constants #
    best_val_loss = 100
    best_val_improv = 0

    # Prepare Network #
    if config.network == 'dnn':
        model = DNN(config.seq_length, config.hidden_size, config.output_size).to(device)
    elif config.network == 'cnn':
        model = CNN(config.seq_length, config.batch_size).to(device)
    elif config.network == 'rnn':
        model = RNN(config.input_size, config.hidden_size, config.num_layers, config.output_size).to(device)
    elif config.network == 'lstm':
        model = LSTM(config.input_size, config.hidden_size, config.num_layers, config.output_size, config.bidirectional).to(device)
    elif config.network == 'gru':
        model = GRU(config.input_size, config.hidden_size, config.num_layers, config.output_size).to(device)
    elif config.network == 'recursive':
        model = RecursiveLSTM(config.input_size, config.hidden_size, config.num_layers, config.output_size).to(device)
    elif config.network == 'attention':
        model = AttentionLSTM(config.input_size, config.key, config.query, config.value, config.hidden_size, config.num_layers, config.output_size, config.bidirectional).to(device)
    else:
        raise NotImplementedError

    # Loss Function #
    criterion = torch.nn.MSELoss()

    # Optimizer #
    optim = torch.optim.Adam(model.parameters(), lr=config.lr, betas=(0.5, 0.999))
    optim_scheduler = get_lr_scheduler(config.lr_scheduler, optim)

    # Train and Validation #
    if config.mode == 'train':

        # Train #
        print("Training {} started with total epoch of {}.".format(model.__class__.__name__, config.num_epochs))

        for epoch in range(config.num_epochs):
            for i, (data, label) in enumerate(train_loader):

                # Prepare Data #
                data = data.to(device, dtype=torch.float32)
                label = label.to(device, dtype=torch.float32)

                # Forward Data #
                pred = model(data)

                # Calculate Loss #
                train_loss = criterion(pred, label)

                # Initialize Optimizer, Back Propagation and Update #
                optim.zero_grad()
                train_loss.backward()
                optim.step()

                # Add item to Lists #
                train_losses.append(train_loss.item())

            # Print Statistics #
            if (epoch+1) % config.print_every == 0:
                print("Epoch [{}/{}]".format(epoch+1, config.num_epochs))
                print("Train Loss {:.4f}".format(np.average(train_losses)))

            # Learning Rate Scheduler #
            optim_scheduler.step()

            # Validation #
            with torch.no_grad():
                for i, (data, label) in enumerate(val_loader):

                    # Prepare Data #
                    data = data.to(device, dtype=torch.float32)
                    label = label.to(device, dtype=torch.float32)

                    # Forward Data #
                    pred_val = model(data)

                    # Calculate Loss #
                    val_loss = criterion(pred_val, label)
                    val_mae = mean_absolute_error(label.cpu(), pred_val.cpu())
                    val_mse = mean_squared_error(label.cpu(), pred_val.cpu(), squared=True)
                    val_rmse = mean_squared_error(label.cpu(), pred_val.cpu(), squared=False)
                    val_mpe = mean_percentage_error(label.cpu(), pred_val.cpu())
                    val_mape = mean_absolute_percentage_error(label.cpu(), pred_val.cpu())
                    val_r2 = r2_score(label.cpu(), pred_val.cpu())

                    # Add item to Lists #
                    val_losses.append(val_loss.item())
                    val_maes.append(val_mae.item())
                    val_mses.append(val_mse.item())
                    val_rmses.append(val_rmse.item())
                    val_mpes.append(val_mpe.item())
                    val_mapes.append(val_mape.item())
                    val_r2s.append(val_r2.item())

            if (epoch + 1) % config.print_every == 0:

                # Print Statistics #
                print("Val Loss {:.4f}".format(np.average(val_losses)))
                print("Val  MAE : {:.4f}".format(np.average(val_maes)))
                print("Val  MSE : {:.4f}".format(np.average(val_mses)))
                print("Val RMSE : {:.4f}".format(np.average(val_rmses)))
                print("Val  MPE : {:.4f}".format(np.average(val_mpes)))
                print("Val MAPE : {:.4f}".format(np.average(val_mapes)))
                print("Val  R^2 : {:.4f}".format(np.average(val_r2s)))

                # Save the model Only if validation loss decreased #
                curr_val_loss = np.average(val_losses)

                if curr_val_loss < best_val_loss:
                    best_val_loss = min(curr_val_loss, best_val_loss)
                    torch.save(model.state_dict(), os.path.join(config.weights_path, 'BEST_{}.pkl'.format(model.__class__.__name__)))

                    print("Best model is saved!\n")
                    best_val_improv = 0

                elif curr_val_loss >= best_val_loss:
                    best_val_improv += 1
                    print("Best Validation has not improved for {} epochs.\n".format(best_val_improv))

    elif config.mode == 'test':

        # Load the Model Weight #
        model.load_state_dict(torch.load(os.path.join(config.weights_path, 'BEST_{}.pkl'.format(model.__class__.__name__))))

        # Test #
        with torch.no_grad():
            for i, (data, label) in enumerate(test_loader):

                # Prepare Data #
                data = data.to(device, dtype=torch.float32)
                label = label.to(device, dtype=torch.float32)

                # Forward Data #
                pred_test = model(data)

                # Convert to Original Value Range #
                pred_test = pred_test.data.cpu().numpy()
                label = label.data.cpu().numpy().reshape(-1, 1)

                pred_test = scaler.inverse_transform(pred_test)
                label = scaler.inverse_transform(label)

                # Calculate Loss #
                test_mae = mean_absolute_error(label, pred_test)
                test_mse = mean_squared_error(label, pred_test, squared=True)
                test_rmse = mean_squared_error(label, pred_test, squared=False)
                test_mpe = mean_percentage_error(label, pred_test)
                test_mape = mean_absolute_percentage_error(label, pred_test)
                test_r2 = r2_score(label, pred_test)

                # Add item to Lists #
                test_maes.append(test_mae.item())
                test_mses.append(test_mse.item())
                test_rmses.append(test_rmse.item())
                test_mpes.append(test_mpe.item())
                test_mapes.append(test_mape.item())
                test_r2s.append(test_r2.item())

            # Print Statistics #
            print("Test {}".format(model.__class__.__name__))
            print("Test  MAE : {:.4f}".format(np.average(test_maes)))
            print("Test  MSE : {:.4f}".format(np.average(test_mses)))
            print("Test RMSE : {:.4f}".format(np.average(test_rmses)))
            print("Test  MPE : {:.4f}".format(np.average(test_mpes)))
            print("Test MAPE : {:.4f}".format(np.average(test_mapes)))
            print("Test  R^2 : {:.4f}".format(np.average(test_r2s)))

            # Plot Figure #
            plot_pred_test(pred_test, label, config.plots_path, config.feature, model)
Example #10
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    paths = [make_dirs(path) for path in paths]

    # Prepare Data Loader #
    train_horse_loader, train_zebra_loader = get_horse2zebra_loader('train', config.batch_size)
    val_horse_loader, val_zebra_loader = get_horse2zebra_loader('test', config.batch_size)
    total_batch = min(len(train_horse_loader), len(train_zebra_loader))

    # Image Pool #
    masked_fake_A_pool = ImageMaskPool(config.pool_size)
    masked_fake_B_pool = ImageMaskPool(config.pool_size)

    # Prepare Networks #
    Attn_A = Attention()
    Attn_B = Attention()
    G_A2B = Generator()
    G_B2A = Generator()
    D_A = Discriminator()
    D_B = Discriminator()

    networks = [Attn_A, Attn_B, G_A2B, G_B2A, D_A, D_B]
    for network in networks:
        network.to(device)

    # Loss Function #
    criterion_Adversarial = nn.MSELoss()
    criterion_Cycle = nn.L1Loss()

    # Optimizers #
    D_optim = torch.optim.Adam(chain(D_A.parameters(), D_B.parameters()), lr=config.lr, betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(chain(Attn_A.parameters(), Attn_B.parameters(), G_A2B.parameters(), G_B2A.parameters()), lr=config.lr, betas=(0.5, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_A_losses, D_B_losses = [], []
    G_A_losses, G_B_losses = [], []

    # Train #
    print("Training Unsupervised Attention-Guided GAN started with total epoch of {}.".format(config.num_epochs))

    for epoch in range(config.num_epochs):

        for i, (real_A, real_B) in enumerate(zip(train_horse_loader, train_zebra_loader)):

            # Data Preparation #
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            ###################
            # Train Generator #
            ###################

            set_requires_grad([D_A, D_B], requires_grad=False)

            # Adversarial Loss using real A #
            attn_A = Attn_A(real_A)
            fake_B = G_A2B(real_A)

            masked_fake_B = fake_B * attn_A + real_A * (1-attn_A)

            masked_fake_B *= attn_A
            prob_real_A = D_A(masked_fake_B)
            real_labels = torch.ones(prob_real_A.size()).to(device)

            G_loss_A = criterion_Adversarial(prob_real_A, real_labels)

            # Adversarial Loss using real B #
            attn_B = Attn_B(real_B)
            fake_A = G_B2A(real_B)

            masked_fake_A = fake_A * attn_B + real_B * (1-attn_B)

            masked_fake_A *= attn_B
            prob_real_B = D_B(masked_fake_A)
            real_labels = torch.ones(prob_real_B.size()).to(device)

            G_loss_B = criterion_Adversarial(prob_real_B, real_labels)

            # Cycle Consistency Loss using real A #
            attn_ABA = Attn_B(masked_fake_B)
            fake_ABA = G_B2A(masked_fake_B)
            masked_fake_ABA = fake_ABA * attn_ABA + masked_fake_B * (1 - attn_ABA)

            # Cycle Consistency Loss using real B #
            attn_BAB = Attn_A(masked_fake_A)
            fake_BAB = G_A2B(masked_fake_A)
            masked_fake_BAB = fake_BAB * attn_BAB + masked_fake_A * (1 - attn_BAB)

            # Cycle Consistency Loss #
            G_cycle_loss_A = config.lambda_cycle * criterion_Cycle(masked_fake_ABA, real_A)
            G_cycle_loss_B = config.lambda_cycle * criterion_Cycle(masked_fake_BAB, real_B)

            # Total Generator Loss #
            G_loss = G_loss_A + G_loss_B + G_cycle_loss_A + G_cycle_loss_B

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad([D_A, D_B], requires_grad=True)

            # Train Discriminator A using real B #
            prob_real_A = D_A(real_B)
            real_labels = torch.ones(prob_real_A.size()).to(device)
            D_loss_real_A = criterion_Adversarial(prob_real_A, real_labels)

            # Add Pooling #
            masked_fake_B, attn_A = masked_fake_B_pool.query(masked_fake_B, attn_A)
            masked_fake_B *= attn_A

            # Train Discriminator A using fake B #
            prob_fake_B = D_A(masked_fake_B.detach())
            fake_labels = torch.zeros(prob_fake_B.size()).to(device)
            D_loss_fake_A = criterion_Adversarial(prob_fake_B, fake_labels)

            D_loss_A = (D_loss_real_A + D_loss_fake_A).mean()

            # Train Discriminator B using real A #
            prob_real_B = D_B(real_A)
            real_labels = torch.ones(prob_real_B.size()).to(device)
            D_loss_real_B = criterion_Adversarial(prob_real_B, real_labels)

            # Add Pooling #
            masked_fake_A, attn_B = masked_fake_A_pool.query(masked_fake_A, attn_B)
            masked_fake_A *= attn_B

            # Train Discriminator B using fake A #
            prob_fake_A = D_B(masked_fake_A.detach())
            fake_labels = torch.zeros(prob_fake_A.size()).to(device)
            D_loss_fake_B = criterion_Adversarial(prob_fake_A, fake_labels)

            D_loss_B = (D_loss_real_B + D_loss_fake_B).mean()

            # Calculate Total Discriminator Loss #
            D_loss = D_loss_A + D_loss_B

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            # Add items to Lists #
            D_A_losses.append(D_loss_A.item())
            D_B_losses.append(D_loss_B.item())
            G_A_losses.append(G_loss_A.item())
            G_B_losses.append(G_loss_B.item())

            ####################
            # Print Statistics #
            ####################

            if (i+1) % config.print_every == 0:
                print("UAG-GAN | Epoch [{}/{}] | Iteration [{}/{}] | D A Losses {:.4f} | D B Losses {:.4f} | G A Losses {:.4f} | G B Losses {:.4f}".
                      format(epoch+1, config.num_epochs, i+1, total_batch, np.average(D_A_losses), np.average(D_B_losses), np.average(G_A_losses), np.average(G_B_losses)))

                # Save Sample Images #
                save_samples(val_horse_loader, val_zebra_loader, G_A2B, G_B2A, Attn_A, Attn_B, epoch, config.samples_path)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(G_A2B.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Generator_A2B_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(G_B2A.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Generator_B2A_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(Attn_A.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Attention_A_Epoch_{}.pkl'.format(epoch+1)))
            torch.save(Attn_B.state_dict(), os.path.join(config.weights_path, 'UAG-GAN_Attention_B_Epoch_{}.pkl'.format(epoch+1)))

    # Make a GIF file #
    make_gifs_train("UAG-GAN", config.samples_path)

    # Plot Losses #
    plot_losses(D_A_losses, D_B_losses, G_A_losses, G_B_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
Example #11
def main():
    args = parse_args()

    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED
    gpus = [int(i) for i in config.GPUS.split(',')]

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # initialize generator and discriminator
    G_AB = eval('models.cyclegan.get_generator')(config.DATA.IMAGE_SHAPE,
                                                 config.NETWORK.NUM_RES_BLOCKS)
    G_BA = eval('models.cyclegan.get_generator')(config.DATA.IMAGE_SHAPE,
                                                 config.NETWORK.NUM_RES_BLOCKS)
    D_A = eval('models.cyclegan.get_discriminator')(config.DATA.IMAGE_SHAPE)
    D_B = eval('models.cyclegan.get_discriminator')(config.DATA.IMAGE_SHAPE)
    #logger.info(pprint.pformat(G_AB))
    #logger.info(pprint.pformat(D_A))

    # multi-gpus

    model_dict = {}
    model_dict['G_AB'] = torch.nn.DataParallel(G_AB, device_ids=gpus).cuda()
    model_dict['G_BA'] = torch.nn.DataParallel(G_BA, device_ids=gpus).cuda()
    model_dict['D_A'] = torch.nn.DataParallel(D_A, device_ids=gpus).cuda()
    model_dict['D_B'] = torch.nn.DataParallel(D_B, device_ids=gpus).cuda()

    # loss functions
    criterion_dict = {}
    criterion_dict['GAN'] = torch.nn.MSELoss().cuda()
    criterion_dict['cycle'] = torch.nn.L1Loss().cuda()
    criterion_dict['identity'] = torch.nn.L1Loss().cuda()

    # optimizers
    optimizer_dict = {}
    optimizer_dict['G'] = get_optimizer(
        config, itertools.chain(G_AB.parameters(), G_BA.parameters()))
    optimizer_dict['D_A'] = get_optimizer(config, D_A.parameters())
    optimizer_dict['D_B'] = get_optimizer(config, D_B.parameters())

    start_epoch = config.TRAIN.START_EPOCH
    if config.TRAIN.RESUME:
        start_epoch, model_dict, optimizer_dict = load_checkpoint(
            model_dict, optimizer_dict, final_output_dir)

    # learning rate schedulers
    lr_scheduler_dict = {}
    lr_scheduler_dict['G'] = get_lr_scheduler(config, optimizer_dict['G'])
    lr_scheduler_dict['D_A'] = get_lr_scheduler(config, optimizer_dict['D_A'])
    lr_scheduler_dict['D_B'] = get_lr_scheduler(config, optimizer_dict['D_B'])
    for steps in range(start_epoch):
        for lr_scheduler in lr_scheduler_dict.values():
            lr_scheduler.step()

    # Buffers of previously generated samples
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Image transformations
    transforms_ = [
        #transforms.Resize(int(config.img_height * 1.12), Image.BICUBIC),
        #transforms.RandomCrop((config.img_height, config.img_width)),
        #transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    # Dataset
    logger.info('=> loading train and testing dataset...')

    train_dataset = ImageDataset(config.DATA.TRAIN_DATASET_B,
                                 config.DATA.TRAIN_DATASET,
                                 transforms_=transforms_)
    test_dataset = ImageDataset(config.DATA.TEST_DATASET_B,
                                config.DATA.TEST_DATASET,
                                transforms_=transforms_,
                                mode='test')
    # Training data loader
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.TRAIN.BATCH_SIZE *
                                  len(gpus),
                                  shuffle=config.TRAIN.SHUFFLE,
                                  num_workers=config.NUM_WORKERS)
    # Test data loader
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=config.TEST.BATCH_SIZE * len(gpus),
                                 shuffle=False,
                                 num_workers=config.NUM_WORKERS)

    for epoch in range(start_epoch, config.TRAIN.END_EPOCH):

        train(config, epoch, model_dict, fake_A_buffer, fake_B_buffer,
              train_dataloader, criterion_dict, optimizer_dict,
              lr_scheduler_dict, writer_dict)

        test(config, model_dict, test_dataloader, criterion_dict,
             final_output_dir)

        for lr_scheduler in lr_scheduler_dict.values():
            lr_scheduler.step()

        if config.TRAIN.CHECKPOINT_INTERVAL != -1 and epoch % config.TRAIN.CHECKPOINT_INTERVAL == 0:
            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model': 'cyclegan',
                    'state_dict_G_AB': model_dict['G_AB'].module.state_dict(),
                    'state_dict_G_BA': model_dict['G_BA'].module.state_dict(),
                    'state_dict_D_A': model_dict['D_A'].module.state_dict(),
                    'state_dict_D_B': model_dict['D_B'].module.state_dict(),
                    'optimizer_G': optimizer_dict['G'].state_dict(),
                    'optimizer_D_A': optimizer_dict['D_A'].state_dict(),
                    'optimizer_D_B': optimizer_dict['D_B'].state_dict(),
                }, final_output_dir)

    writer_dict['writer'].close()
Example #12
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    paths = [make_dirs(path) for path in paths]

    # Prepare Data Loader #
    train_loader = get_edges2shoes_loader('train', config.batch_size)
    val_loader = get_edges2shoes_loader('val', config.val_batch_size)
    total_batch = len(train_loader)

    # Prepare Networks #
    G_A2B = Generator().to(device)
    G_B2A = Generator().to(device)
    D_A = Discriminator().to(device)
    D_B = Discriminator().to(device)

    # Loss Function #
    criterion_Adversarial = nn.BCELoss()
    criterion_Recon = nn.MSELoss()
    criterion_Feature = nn.HingeEmbeddingLoss()

    # Optimizers #
    G_optim = torch.optim.Adam(chain(G_A2B.parameters(), G_B2A.parameters()),
                               config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.00001)
    D_optim = torch.optim.Adam(chain(D_A.parameters(), D_B.parameters()),
                               config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.00001)

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses, G_losses = [], []

    # Constants #
    iters = 0

    # Training #
    print("Training DiscoGAN started with total epoch of {}.".format(
        config.num_epochs))

    for epoch in range(config.num_epochs):

        for i, (real_A, real_B) in enumerate(train_loader):

            # Data Preparation #
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # Initialize Models #
            G_A2B.zero_grad()
            G_B2A.zero_grad()
            D_A.zero_grad()
            D_B.zero_grad()

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            ################
            # Forward Data #
            ################

            fake_B = G_A2B(real_A)
            fake_A = G_B2A(real_B)

            prob_real_A, A_real_features = D_A(real_A)
            prob_fake_A, A_fake_features = D_A(fake_A)

            prob_real_B, B_real_features = D_B(real_B)
            prob_fake_B, B_fake_features = D_B(fake_B)

            #######################
            # Train Discriminator #
            #######################

            # Discriminator A #
            real_labels = Variable(torch.ones(prob_real_A.size()),
                                   requires_grad=False).to(device)
            D_real_loss_A = criterion_Adversarial(prob_real_A, real_labels)

            fake_labels = Variable(torch.zeros(prob_fake_A.size()),
                                   requires_grad=False).to(device)
            D_fake_loss_A = criterion_Adversarial(prob_fake_A, fake_labels)

            D_loss_A = (D_real_loss_A + D_fake_loss_A).mean()

            # Discriminator B #
            real_labels = Variable(torch.ones(prob_real_B.size()),
                                   requires_grad=False).to(device)
            D_real_loss_B = criterion_Adversarial(prob_real_B, real_labels)

            fake_labels = Variable(torch.zeros(prob_fake_B.size()),
                                   requires_grad=False).to(device)
            D_fake_loss_B = criterion_Adversarial(prob_fake_B, fake_labels)

            D_loss_B = (D_real_loss_B + D_fake_loss_B).mean()

            # Calculate Total Discriminator Loss #
            D_loss = D_loss_A + D_loss_B

            ###################
            # Train Generator #
            ###################

            # Adversarial Loss #
            real_labels = Variable(torch.ones(prob_real_A.size()),
                                   requires_grad=False).to(device)
            G_adv_loss_A = criterion_Adversarial(prob_fake_A, real_labels)

            real_labels = Variable(torch.ones(prob_real_B.size()),
                                   requires_grad=False).to(device)
            G_adv_loss_B = criterion_Adversarial(prob_fake_B, real_labels)

            # Feature Loss #
            G_feature_loss_A = feature_loss(criterion_Feature, A_real_features,
                                            A_fake_features)
            G_feature_loss_B = feature_loss(criterion_Feature, B_real_features,
                                            B_fake_features)

            # Reconstruction Loss #
            fake_ABA = G_B2A(fake_B)
            fake_BAB = G_A2B(fake_A)

            G_recon_loss_A = criterion_Recon(fake_ABA, real_A)
            G_recon_loss_B = criterion_Recon(fake_BAB, real_B)

            if iters < config.decay_gan_loss:
                rate = config.starting_rate
            else:
                print("Now the rate is changed to {}".format(
                    config.changed_rate))
                rate = config.changed_rate

            G_loss_A = (G_adv_loss_A * 0.1 + G_feature_loss_A * 0.9) * (
                1. - rate) + G_recon_loss_A * rate
            G_loss_B = (G_adv_loss_B * 0.1 + G_feature_loss_B * 0.9) * (
                1. - rate) + G_recon_loss_B * rate

            # Calculate Total Generator Loss #
            G_loss = G_loss_A + G_loss_B

            # Back Propagation and Update #
            if iters % config.num_train_gen == 0:
                D_loss.backward()
                D_optim.step()
            else:
                G_loss.backward()
                G_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "DiscoGAN | Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(G_losses)))

                # Save Sample Images #
                sample_images(val_loader, G_A2B, G_B2A, epoch,
                              config.samples_path)

            iters += 1

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G_A2B.state_dict(),
                os.path.join(
                    config.weights_path,
                    'DiscoGAN_Generator_A2B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_B2A.state_dict(),
                os.path.join(
                    config.weights_path,
                    'DiscoGAN_Generator_B2A_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train('DiscoGAN', config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.num_epochs, config.plots_path)

    print("Training finished.")
Example #13
File: main.py Project: omihub777/sim-real
    project_name="sim_real",
    auto_metric_logging=True,
    auto_param_logging=True,
)

if args.mixed_precision:
    print("Applied: Mixed Precision")
    tf.keras.mixed_precision.set_global_policy("mixed_float16")

train_ds, test_ds = get_dataset(args)
grid = image_grid(next(iter(train_ds))[0])[0]
logger.log_image(grid.numpy())
model = get_model(args)
criterion = get_criterion(args)
optimizer = get_optimizer(args)
lr_scheduler = get_lr_scheduler(args)
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', mode='max', patience=args.patience, restore_best_weights=True)
experiment_name = get_experiment_name(args)
logger.set_name(experiment_name)
logger.log_parameters(vars(args))
with logger.train():
    filename = f'{args.model_name}.hdf5'
    checkpoint = tf.keras.callbacks.ModelCheckpoint(filename, monitor='val_accuracy', mode='max', save_best_only=True, verbose=True)

    model.compile(loss=criterion, optimizer=optimizer, metrics=['accuracy'])
    if args.dry_run:
        print("[INFO] Turn off all callbacks")
        model.fit(train_ds, validation_data=test_ds, epochs=args.epochs, steps_per_epoch=2)
    else:
        model.fit(train_ds, validation_data=test_ds, epochs=args.epochs, callbacks=[lr_scheduler, early_stop, checkpoint])
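
Unlike the other examples, this one is TensorFlow/Keras, so `get_lr_scheduler(args)` must return a Keras callback because it is passed to `model.fit`. A sketch assuming a cosine schedule (`args.epochs` appears above; `args.lr` and `args.min_lr` are illustrative assumptions):

import math
import tensorflow as tf

def get_lr_scheduler(args):
    # Assumed: cosine-anneal the learning rate from args.lr down to
    # args.min_lr over args.epochs epochs, wrapped as a Keras callback.
    def schedule(epoch, lr):
        t = epoch / max(1, args.epochs - 1)
        return args.min_lr + 0.5 * (args.lr - args.min_lr) * (1.0 + math.cos(math.pi * t))

    return tf.keras.callbacks.LearningRateScheduler(schedule, verbose=1)
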
Example #14
def main(config):

    # For Reproducibility #
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config.seed)

    # Weights and Plots Path #
    paths = [config.weights_path, config.plots_path]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    if config.dataset == 'cifar':
        train_loader, val_loader, test_loader = cifar_loader(
            config.num_classes, config.batch_size)
        input_size = 32

    # Prepare Networks #
    if config.model == 'vit':
        model = VisionTransformer(in_channels=config.in_channels,
                                  embed_dim=config.embed_dim,
                                  patch_size=config.patch_size,
                                  num_layers=config.num_layers,
                                  num_heads=config.num_heads,
                                  mlp_dim=config.mlp_dim,
                                  dropout=config.drop_out,
                                  input_size=input_size,
                                  num_classes=config.num_classes).to(device)

    elif config.model == 'efficient':
        model = EfficientNet.from_name(
            'efficientnet-b0', num_classes=config.num_classes).to(device)

    elif config.model == 'resnet':
        model = resnet34(pretrained=False).to(device)
        model.fc = nn.Linear(config.mlp_dim, config.num_classes).to(device)

    else:
        raise NotImplementedError

    # Weight Initialization #
    if not config.model == 'efficient':
        if config.init == 'normal':
            model.apply(init_weights_normal)
        elif config.init == 'xavier':
            model.apply(init_weights_xavier)
        elif config.init == 'he':
            model.apply(init_weights_kaiming)
        else:
            raise NotImplementedError

    # Train #
    if config.phase == 'train':

        # Loss Function #
        criterion = nn.CrossEntropyLoss()

        # Optimizers #
        if config.num_classes == 10:
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=config.lr,
                                         betas=(0.5, 0.999))
            optimizer_scheduler = get_lr_scheduler(config.lr_scheduler,
                                                   optimizer)
        elif config.num_classes == 100:
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=config.lr,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            optimizer_scheduler = get_lr_scheduler('step', optimizer)

        # Constants #
        best_top1_acc = 0

        # Lists #
        train_losses, val_losses = list(), list()
        train_top1_accs, train_top5_accs = list(), list()
        val_top1_accs, val_top5_accs = list(), list()

        # Train and Validation #
        print("Training {} has started.".format(model.__class__.__name__))
        for epoch in range(config.num_epochs):

            # Train #
            train_loss, train_top1_acc, train_top5_acc = train(
                train_loader, model, optimizer, criterion, epoch, config)

            # Validation #
            val_loss, val_top1_acc, val_top5_acc = validate(
                val_loader, model, criterion, epoch, config)

            # Add items to Lists #
            train_losses.append(train_loss)
            val_losses.append(val_loss)

            train_top1_accs.append(train_top1_acc)
            train_top5_accs.append(train_top5_acc)

            val_top1_accs.append(val_top1_acc)
            val_top5_accs.append(val_top5_acc)

            # If Best Top 1 Accuracy #
            if val_top1_acc > best_top1_acc:
                best_top1_acc = max(val_top1_acc, best_top1_acc)

                # Save Models #
                print("The best model is saved!")
                torch.save(
                    model.state_dict(),
                    os.path.join(
                        config.weights_path,
                        'BEST_{}_{}_{}.pkl'.format(model.__class__.__name__,
                                                   str(config.dataset).upper(),
                                                   config.num_classes)))

            print("Best Top 1 Accuracy {:.2f}%\n".format(best_top1_acc))

            # Optimizer Scheduler #
            optimizer_scheduler.step()

        # Plot Losses and Accuracies #
        losses = (train_losses, val_losses)
        accs = (train_top1_accs, train_top5_accs, val_top1_accs, val_top5_accs)
        plot_metrics(losses, accs, config.plots_path, model, config.dataset,
                     config.num_classes)

        print("Training {} using {} {} finished.".format(
            model.__class__.__name__,
            str(config.dataset).upper(), config.num_classes))

    # Test #
    elif config.phase == 'test':

        test(test_loader, model, config)

    else:
        raise NotImplementedError
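
In Examples #4, #9, #11 and #14 the first argument is a configuration object or a scheduler name rather than the optimizer. A sketch of such a dispatcher, assuming a plain name string as in Examples #9 and #14 (the supported names and their hyperparameters are illustrative assumptions):

import torch

def get_lr_scheduler(lr_scheduler, optimizer):
    # Assumed dispatcher keyed on a scheduler name string.
    if lr_scheduler == 'step':
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    if lr_scheduler == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
    if lr_scheduler == 'exp':
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    raise NotImplementedError(lr_scheduler)
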
Example #15
def main(opts):
    """Main function for the training pipeline
    :opts: command-line arguments
    :returns: None
    """
    ##########################################################################
    #                             Basic settings                             #
    ##########################################################################
    exp_dir = 'experiments'
    log_dir = os.path.join(exp_dir, 'logs')
    model_dir = os.path.join(exp_dir, 'models')
    os.makedirs(os.path.join(model_dir, opts.run_name), exist_ok=True)
    os.makedirs(os.path.join(log_dir, opts.run_name))

    pprint(vars(opts))
    with open(os.path.join(log_dir, opts.run_name, "args.json"), 'w') as f:
        json.dump(vars(opts), f, indent=True)

    torch.manual_seed(opts.seed)
    np.random.seed(opts.seed)
    random.seed(opts.seed)

    ##########################################################################
    #  Define all the necessary variables for model training and evaluation  #
    ##########################################################################
    writer = SummaryWriter(os.path.join(log_dir, opts.run_name), flush_secs=5)

    if opts.train_mode == 'combined':
        train_dataset = get_train_dataset(opts.data_root, opts, opts.folder1,
                                          opts.folder2, opts.folder3)
    elif opts.train_mode == 'oversampling':
        train_dataset = get_train_dataset_by_oversampling(
            opts.data_root, opts, opts.folder1, opts.folder2, opts.folder3)
    elif opts.train_mode == 'pretrain_and_finetune':
        train_dataset, finetune_dataset = get_pretrain_and_finetune_datast(
            opts.data_root, opts, opts.folder1, opts.folder2, opts.folder3)
        finetune_loader = torch.utils.data.DataLoader(
            finetune_dataset,
            batch_size=opts.batch_size,
            num_workers=opts.num_workers,
            drop_last=False,
            shuffle=True)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opts.batch_size,
                                               num_workers=opts.num_workers,
                                               drop_last=False,
                                               shuffle=True)

    val_dataset = get_val_dataset(os.path.join('data', 'val'), opts)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=opts.eval_batch_size,
                                             shuffle=False,
                                             num_workers=opts.num_workers,
                                             drop_last=False)

    test_dataset = get_test_dataset(os.path.join('data', 'test'), opts)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=opts.eval_batch_size,
                                              shuffle=False,
                                              num_workers=opts.num_workers,
                                              drop_last=False)

    assert train_dataset.class_to_idx == val_dataset.class_to_idx == test_dataset.class_to_idx, "Mapping not correct"

    model = get_model(opts)

    opts.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if torch.cuda.device_count() > 1 and not opts.no_data_parallel:
        model = nn.DataParallel(model)

    model = model.to(opts.device)

    optimizer = optim.RMSprop(model.parameters(),
                              lr=opts.lr,
                              alpha=0.9,
                              weight_decay=1e-5,
                              momentum=0.9)
    scheduler = get_lr_scheduler(optimizer, opts)

    best_val_loss = float('inf')
    best_val_accu = 0.0
    best_val_rec = 0.0
    best_val_prec = 0.0
    best_val_f1 = 0.0
    best_val_auc = 0.0

    iteration_change_loss = 0
    t_start_training = time.time()

    ##########################################################################
    #                           Main training loop                           #
    ##########################################################################
    for epoch in range(opts.epochs):
        current_lr = get_lr(optimizer)
        t_start = time.time()

        ############################################################
        #  The actual training and validation step for each epoch  #
        ############################################################
        train_loss, train_metric = train_model(model, train_loader, optimizer,
                                               opts)

        if epoch == opts.finetune_epoch and opts.train_mode == 'pretrain_and_finetune':
            train_loader = finetune_loader
            optimizer = optim.RMSprop(model.parameters(),
                                      lr=opts.lr,
                                      alpha=0.9,
                                      weight_decay=1e-5,
                                      momentum=0.9)
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=opts.step_size_finetuning,
                gamma=opts.gamma)

        # Run the validation set
        with torch.no_grad():
            val_loss, val_metric = evaluate_model(model, val_loader, opts)

        ##############################
        #  Write to summary writer   #
        ##############################

        train_acc, val_acc = train_metric['accuracy'], val_metric['accuracy']
        train_rec, val_rec = train_metric['recalls'], val_metric['recalls']
        train_prec, val_prec = train_metric['precisions'], val_metric[
            'precisions']
        train_f1, val_f1 = train_metric['f1'], val_metric['f1']
        train_auc, val_auc = train_metric['auc'], val_metric['auc']

        writer.add_scalar('Loss/Train', train_loss, epoch)
        writer.add_scalar('Accuracy/Train', train_acc, epoch)
        writer.add_scalar('Precision/Train', train_prec, epoch)
        writer.add_scalar('Recall/Train', train_rec, epoch)
        writer.add_scalar('F1/Train', train_f1, epoch)
        writer.add_scalar('AUC/Train', train_auc, epoch)

        writer.add_scalar('Loss/Val', val_loss, epoch)
        writer.add_scalar('Accuracy/Val', val_acc, epoch)
        writer.add_scalar('Precision/Val', val_prec, epoch)
        writer.add_scalar('Recall/Val', val_rec, epoch)
        writer.add_scalar('F1/Val', val_f1, epoch)
        writer.add_scalar('AUC/Val', val_auc, epoch)

        ##############################
        #  Adjust the learning rate  #
        ##############################
        if opts.lr_scheduler == 'plateau':
            scheduler.step(val_loss)
        elif opts.lr_scheduler in ['step', 'cosine']:
            scheduler.step()

        t_end = time.time()
        delta = t_end - t_start

        print_epoch_progress(epoch, opts.epochs, train_loss, val_loss, delta,
                             train_metric, val_metric)
        iteration_change_loss += 1
        print('-' * 30)

        if val_acc > best_val_accu:
            best_val_accu = val_acc
            if bool(opts.save_model):
                torch.save(
                    model.state_dict(),
                    os.path.join(model_dir, opts.run_name,
                                 'best_state_dict.pth'))

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            iteration_change_loss = 0

        if val_rec > best_val_rec:
            best_val_rec = val_rec

        if val_prec > best_val_prec:
            best_val_prec = val_prec

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            print(f'The best validation F1-score is now {best_val_f1}')
            print(
                f'The validation accuracy and AUC are now {val_acc} and {val_auc}'
            )

        if val_auc > best_val_auc:
            best_val_auc = val_auc

        if iteration_change_loss == opts.patience and opts.early_stopping:
            print(
                ('Early stopping after {0} iterations without a decrease ' +
                 'in the val loss').format(iteration_change_loss))
            break

    t_end_training = time.time()
    print(f'training took {t_end_training - t_start_training}s')
    print(f'Best validation accuracy: {best_val_accu}')
    print(f'Best validation loss: {best_val_loss}')
    print(f'Best validation precision: {best_val_prec}')
    print(f'Best validation recall: {best_val_rec}')
    print(f'Best validation f1: {best_val_f1}')
    print(f'Best validation AUC: {best_val_auc}')

    with torch.no_grad():
        if opts.train_mode in ['combined', 'oversampling']:
            model.load_state_dict(
                torch.load(
                    os.path.join(model_dir, opts.run_name,
                                 'best_state_dict.pth')))
        test_loss, test_metric = evaluate_model(model, test_loader, opts)

    print(f'The best test F1: {test_metric["f1"]}')
    print(f'The best test auc: {test_metric["auc"]}')
    print(f'The best test accuracy: {test_metric["accuracy"]}')
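A minimal sketch of the `get_lr_scheduler(optimizer, opts)` helper assumed above, covering the 'plateau', 'step' and 'cosine' branches stepped in the training loop; the option names `opts.lr_patience`, `opts.step_size` and `opts.gamma` are illustrative assumptions, not taken from the original code.

from torch import optim

def get_lr_scheduler(optimizer, opts):
    if opts.lr_scheduler == 'plateau':
        # stepped with scheduler.step(val_loss) in the loop above
        return optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                    patience=opts.lr_patience)
    elif opts.lr_scheduler == 'step':
        return optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size,
                                         gamma=opts.gamma)
    elif opts.lr_scheduler == 'cosine':
        return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=opts.epochs)
    raise NotImplementedError(opts.lr_scheduler)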
Example #16
0
def train(args, logger, tb_writer):
    logger.info('Args: {}'.format(json.dumps(vars(args), indent=4, sort_keys=True)))
    if args.local_rank in [-1, 0]:
        with open(os.path.join(args.save_dir, 'args.yaml'), 'w') as file:
            yaml.safe_dump(vars(args), file, sort_keys=False)

    device_id = args.local_rank if args.local_rank != -1 else 0
    device = torch.device('cuda', device_id)
    logger.warning(f'Using GPU {args.local_rank}.')

    world_size = torch.distributed.get_world_size() if args.local_rank != -1 else 1
    logger.info(f'Total number of GPUs used: {world_size}.')
    effective_batch_size = args.batch_size * world_size * args.accumulation_steps
    logger.info(f'Effective batch size: {effective_batch_size}.')

    num_train_samples_per_epoch, num_dev_samples, num_unique_train_epochs = get_data_sizes(data_dir=args.data_dir,
                                                                                           num_epochs=args.num_epochs,
                                                                                           logger=logger)
    num_optimization_steps = sum(num_train_samples_per_epoch) // world_size // args.batch_size // \
                             args.accumulation_steps
    if args.max_steps > 0:
        num_optimization_steps = min(num_optimization_steps, args.max_steps)
    logger.info(f'Total number of optimization steps: {num_optimization_steps}.')

    # Set random seed
    logger.info(f'Using random seed {args.seed}.')
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Get model
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    logger.info(f'Loading model {args.model} for task {args.task}...')
    model = ModelRegistry.get_model(args.task).from_pretrained(args.model)

    if args.local_rank in [-1, 0]:
        with open(os.path.join(args.save_dir, 'config.json'), 'w') as file:
            json.dump(model.config.__dict__, file)

    if args.local_rank == 0:
        torch.distributed.barrier()

    model.to(device)

    # Get optimizer
    logger.info('Creating optimizer...')
    parameter_groups = get_parameter_groups(model)
    optimizer = AdamW(parameter_groups, lr=args.learning_rate, weight_decay=args.weight_decay, eps=1e-8)
    scheduler = get_lr_scheduler(optimizer, num_steps=num_optimization_steps, warmup_proportion=args.warmup_proportion)

    if args.amp:
        amp.register_half_function(torch, 'einsum')
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level)

    if args.local_rank != -1:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    # Get dev data loader
    dev_data_file = os.path.join(args.data_dir, 'dev.jsonl.gz')
    logger.info(f'Creating dev dataset from {dev_data_file}...')
    dev_dataset = DatasetRegistry.get_dataset(args.task)(data_file=dev_data_file,
                                                         data_size=num_dev_samples,
                                                         local_rank=-1)
    dev_loader = DataLoader(dev_dataset,
                            batch_size=2 * args.batch_size,
                            num_workers=1,
                            collate_fn=dev_dataset.collate_fn)

    # Get evaluator
    evaluator = EvaluatorRegistry.get_evaluator(args.task)(data_loader=dev_loader,
                                                           logger=logger,
                                                           tb_writer=tb_writer,
                                                           device=device,
                                                           world_size=world_size,
                                                           args=args)

    # Get saver
    saver = CheckpointSaver(save_dir=args.save_dir,
                            max_checkpoints=args.max_checkpoints,
                            primary_metric=evaluator.primary_metric,
                            maximize_metric=evaluator.maximize_metric,
                            logger=logger)

    global_step = 0
    samples_processed = 0

    # Train
    logger.info('Training...')
    samples_till_eval = args.eval_every
    for epoch in range(1, args.num_epochs + 1):
        # Get train data loader for current epoch
        train_data_file_num = ((epoch - 1) % num_unique_train_epochs) + 1
        train_data_file = os.path.join(args.data_dir, f'epoch_{train_data_file_num}.jsonl.gz')
        logger.info(f'Creating training dataset from {train_data_file}...')
        train_dataset = DatasetRegistry.get_dataset(args.task)(train_data_file,
                                                               data_size=num_train_samples_per_epoch[epoch - 1],
                                                               local_rank=args.local_rank,
                                                               world_size=world_size)
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=1,
                                  collate_fn=train_dataset.collate_fn)

        logger.info(f'Starting epoch {epoch}...')
        model.train()
        model.zero_grad()
        loss_values = defaultdict(float)
        samples_till_end = (num_optimization_steps - global_step) * effective_batch_size
        samples_in_cur_epoch = min([len(train_loader.dataset), samples_till_end])
        disable_progress_bar = (args.local_rank not in [-1, 0])
        with tqdm(total=samples_in_cur_epoch, disable=disable_progress_bar) as progress_bar:
            for step, batch in enumerate(train_loader, 1):
                batch = {name: tensor.to(device) for name, tensor in batch.items()}
                current_batch_size = batch['input_ids'].shape[0]

                outputs = model(**batch)
                loss, current_loss_values = outputs[:2]

                loss = loss / args.accumulation_steps
                for name, value in current_loss_values.items():
                    loss_values[name] += value / args.accumulation_steps

                if args.amp:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                samples_processed += current_batch_size * world_size
                samples_till_eval -= current_batch_size * world_size
                progress_bar.update(current_batch_size * world_size)

                if step % args.accumulation_steps == 0:
                    current_lr = scheduler.get_last_lr()[0]

                    if args.amp:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                    global_step += 1

                    # Log info
                    progress_bar.set_postfix(epoch=epoch, step=global_step, lr=current_lr, **loss_values)
                    if args.local_rank in [-1, 0]:
                        tb_writer.add_scalar('train/LR', current_lr, global_step)
                        for name, value in loss_values.items():
                            tb_writer.add_scalar(f'train/{name}', value, global_step)
                    loss_values = {name: 0 for name in loss_values}

                    if global_step == args.max_steps:
                        logger.info('Reached maximum number of optimization steps.')
                        break

                    if samples_till_eval <= 0:
                        samples_till_eval = args.eval_every
                        eval_results = evaluator.evaluate(model, global_step)
                        if args.local_rank in [-1, 0]:
                            saver.save(model, global_step, eval_results)

            if not args.do_not_eval_after_epoch:
                eval_results = evaluator.evaluate(model, global_step)
                if args.local_rank in [-1, 0]:
                    saver.save(model, global_step, eval_results)
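A minimal sketch of a linear warmup-then-decay schedule matching the `get_lr_scheduler(optimizer, num_steps=..., warmup_proportion=...)` call above; the exact decay shape used by the original helper is an assumption.

from torch.optim.lr_scheduler import LambdaLR

def get_lr_scheduler(optimizer, num_steps, warmup_proportion):
    warmup_steps = int(num_steps * warmup_proportion)

    def lr_lambda(step):
        if step < warmup_steps:
            # ramp the learning rate up linearly during warmup
            return step / max(1, warmup_steps)
        # then decay it linearly to zero over the remaining steps
        return max(0.0, (num_steps - step) / max(1, num_steps - warmup_steps))

    return LambdaLR(optimizer, lr_lambda)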
Example #17
0
def train(opt, train_iter, dev_iter, test_iter, syn_data, verbose=True):
    global_start = time.time()
    #logger = utils.getLogger()
    model = models.setup(opt)

    if opt.resume is not None:
        model = set_params(model, opt.resume)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if torch.cuda.is_available():
        model.cuda()
        #model=torch.nn.DataParallel(model)

    # set optimizer
    if opt.embd_freeze:
        model.embedding.weight.requires_grad = False
    else:
        model.embedding.weight.requires_grad = True
    params = [param for param in model.parameters() if param.requires_grad]
    optimizer = utils.getOptimizer(params,
                                   name=opt.optimizer,
                                   lr=opt.learning_rate,
                                   weight_decay=opt.weight_decay,
                                   scheduler=utils.get_lr_scheduler(
                                       opt.lr_scheduler))
    scheduler = WarmupMultiStepLR(optimizer, (40, 80), 0.1, 1.0 / 10.0, 2,
                                  'linear')

    from label_smooth import LabelSmoothSoftmaxCE
    if opt.label_smooth != 0:
        assert 0 < opt.label_smooth <= 1
        loss_fun = LabelSmoothSoftmaxCE(lb_pos=1 - opt.label_smooth,
                                        lb_neg=opt.label_smooth)
    else:
        loss_fun = F.cross_entropy

    filename = None
    acc_adv_list = []
    start = time.time()
    kl_control = 0

    # initialize synonyms with the same embd
    from PWWS.word_level_process import word_process, get_tokenizer
    tokenizer = get_tokenizer(opt)

    if opt.embedding_prep == "same":
        father_dict = {}
        for index in range(1 + len(tokenizer.index_word)):
            father_dict[index] = index

        def get_father(x):
            if father_dict[x] == x:
                return x
            else:
                fa = get_father(father_dict[x])
                father_dict[x] = fa
                return fa

        for index in range(len(syn_data) - 1, 0, -1):
            syn_list = syn_data[index]
            for pos in syn_list:
                fa_pos = get_father(pos)
                fa_anch = get_father(index)
                if fa_pos == fa_anch:
                    father_dict[index] = index
                    father_dict[fa_anch] = index
                else:
                    father_dict[index] = index
                    father_dict[fa_anch] = index
                    father_dict[fa_pos] = index

        print("Same embedding for synonyms as embd prep.")
        set_different_embd = set()
        for key in father_dict:
            fa = get_father(key)
            set_different_embd.add(fa)
            with torch.no_grad():
                model.embedding.weight[key, :] = model.embedding.weight[fa, :]
        print(len(set_different_embd))

    elif opt.embedding_prep == "ge":
        print("Graph embedding as embd prep.")
        ge_file_path = opt.ge_file_path
        f = open(ge_file_path, 'rb')
        saved = pickle.load(f)
        ge_embeddings_dict = saved['walk_embeddings']
        #model = saved['model']
        f.close()
        with torch.no_grad():
            for key in ge_embeddings_dict:
                model.embedding.weight[int(key), :] = torch.FloatTensor(
                    ge_embeddings_dict[key])
    else:
        print("No embd prep.")

    from from_certified.attack_surface import WordSubstitutionAttackSurface, LMConstrainedAttackSurface
    if opt.lm_constraint:
        attack_surface = LMConstrainedAttackSurface.from_files(
            opt.certified_neighbors_file_path, opt.imdb_lm_file_path)
    else:
        attack_surface = WordSubstitutionAttackSurface.from_files(
            opt.certified_neighbors_file_path, opt.imdb_lm_file_path)

    best_adv_acc = 0
    for epoch in range(21):

        if opt.smooth_ce:
            if epoch < 10:
                weight_adv = epoch * 1.0 / 10
                weight_clean = 1 - weight_adv
            else:
                weight_adv = 1
                weight_clean = 0
        else:
            weight_adv = opt.weight_adv
            weight_clean = opt.weight_clean

        if epoch >= opt.kl_start_epoch:
            kl_control = 1

        sum_loss = sum_loss_adv = sum_loss_kl = sum_loss_clean = 0
        total = 0

        for iters, batch in enumerate(train_iter):

            text = batch[0].to(device)
            label = batch[1].to(device)
            anch = batch[2].to(device)
            pos = batch[3].to(device)
            neg = batch[4].to(device)
            anch_valid = batch[5].to(device).unsqueeze(2)
            text_like_syn = batch[6].to(device)
            text_like_syn_valid = batch[7].to(device)

            bs, sent_len = text.shape

            model.train()

            # zero grad
            optimizer.zero_grad()

            if opt.pert_set == "ad_text":
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                    'w_optm_lr': opt.w_optm_lr,
                    'sparse_weight': opt.attack_sparse_weight,
                    'out_type': "text"
                }
                embd = model(mode="text_to_embd",
                             input=text)  #in bs, len sent, vocab
                n, l, s = text_like_syn.shape
                text_like_syn_embd = model(mode="text_to_embd",
                                           input=text_like_syn.reshape(
                                               n, l * s)).reshape(n, l, s, -1)
                text_adv = model(mode="get_adv_by_convex_syn",
                                 input=embd,
                                 label=label,
                                 text_like_syn_embd=text_like_syn_embd,
                                 text_like_syn_valid=text_like_syn_valid,
                                 text_like_syn=text_like_syn,
                                 attack_type_dict=attack_type_dict)

            elif opt.pert_set == "ad_text_syn_p":
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                    'w_optm_lr': opt.w_optm_lr,
                    'sparse_weight': opt.train_attack_sparse_weight,
                    'out_type': "comb_p"
                }
                embd = model(mode="text_to_embd",
                             input=text)  #in bs, len sent, vocab
                n, l, s = text_like_syn.shape
                text_like_syn_embd = model(mode="text_to_embd",
                                           input=text_like_syn.reshape(
                                               n, l * s)).reshape(n, l, s, -1)
                adv_comb_p = model(mode="get_adv_by_convex_syn",
                                   input=embd,
                                   label=label,
                                   text_like_syn_embd=text_like_syn_embd,
                                   text_like_syn_valid=text_like_syn_valid,
                                   attack_type_dict=attack_type_dict)

            elif opt.pert_set == "ad_text_hotflip":
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                }
                text_adv = model(mode="get_adv_hotflip",
                                 input=text,
                                 label=label,
                                 text_like_syn_valid=text_like_syn_valid,
                                 text_like_syn=text_like_syn,
                                 attack_type_dict=attack_type_dict)

            elif opt.pert_set == "l2_ball":
                set_radius = opt.train_attack_eps
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'step_size': opt.train_attack_step_size * set_radius,
                    'random_start': opt.random_start,
                    'epsilon': set_radius,
                    #'loss_func': 'ce',
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                    'direction': 'away',
                    'ball_range': opt.l2_ball_range,
                }
                embd = model(mode="text_to_embd",
                             input=text)  #in bs, len sent, vocab
                embd_adv = model(mode="get_embd_adv",
                                 input=embd,
                                 label=label,
                                 attack_type_dict=attack_type_dict)

            optimizer.zero_grad()
            # clean loss
            predicted = model(mode="text_to_logit", input=text)
            loss_clean = loss_fun(predicted, label)
            # adv loss
            if opt.pert_set == "ad_text" or opt.pert_set == "ad_text_hotflip":
                predicted_adv = model(mode="text_to_logit", input=text_adv)
            elif opt.pert_set == "ad_text_syn_p":
                predicted_adv = model(mode="text_syn_p_to_logit",
                                      input=text_like_syn,
                                      comb_p=adv_comb_p)
            elif opt.pert_set == "l2_ball":
                predicted_adv = model(mode="embd_to_logit", input=embd_adv)

            loss_adv = loss_fun(predicted_adv, label)
            # kl loss
            criterion_kl = nn.KLDivLoss(reduction="sum")
            loss_kl = (1.0 / bs) * criterion_kl(
                F.log_softmax(predicted_adv, dim=1), F.softmax(predicted,
                                                               dim=1))

            # optimize
            loss = opt.weight_kl * kl_control * loss_kl + weight_adv * loss_adv + weight_clean * loss_clean
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            sum_loss_adv += loss_adv.item()
            sum_loss_clean += loss_clean.item()
            sum_loss_kl += loss_kl.item()
            predicted, idx = torch.max(predicted, 1)
            precision = (idx == label).float().mean().item()
            predicted_adv, idx = torch.max(predicted_adv, 1)
            precision_adv = (idx == label).float().mean().item()
            total += 1

            out_log = "%d epoch %d iters: loss: %.3f, loss_kl: %.3f, loss_adv: %.3f, loss_clean: %.3f | acc: %.3f acc_adv: %.3f | in %.3f seconds" % (
                epoch, iters, sum_loss / total, sum_loss_kl / total,
                sum_loss_adv / total, sum_loss_clean / total, precision,
                precision_adv, time.time() - start)
            start = time.time()
            print(out_log)

        scheduler.step()

        if epoch % 1 == 0:
            acc = utils.imdb_evaluation(opt, device, model, dev_iter)
            out_log = "%d epoch with dev acc %.4f" % (epoch, acc)
            print(out_log)
            adv_acc = utils.imdb_evaluation_ascc_attack(
                opt, device, model, dev_iter, tokenizer)
            out_log = "%d epoch with dev adv acc against ascc attack %.4f" % (
                epoch, adv_acc)
            print(out_log)

            #hotflip_adv_acc=utils.evaluation_hotflip_adv(opt, device, model, dev_iter, tokenizer)
            #out_log="%d epoch with dev hotflip adv acc %.4f" % (epoch,hotflip_adv_acc)
            #logger.info(out_log)
            #print(out_log)

            if adv_acc >= best_adv_acc:
                best_adv_acc = adv_acc
                best_save_dir = os.path.join(opt.out_path,
                                             "{}_best.pth".format(opt.model))
                state = {
                    'net': model.state_dict(),
                    'epoch': epoch,
                }
                torch.save(state, best_save_dir)

    # restore best according to dev set
    model = set_params(model, best_save_dir)
    acc = utils.imdb_evaluation(opt, device, model, test_iter)
    print("test acc %.4f" % (acc))
    adv_acc = utils.imdb_evaluation_ascc_attack(opt, device, model, test_iter,
                                                tokenizer)
    print("test adv acc against ascc attack %.4f" % (adv_acc))
    genetic_attack(opt,
                   device,
                   model,
                   attack_surface,
                   dataset=opt.dataset,
                   genetic_test_num=opt.genetic_test_num)
    fool_text_classifier_pytorch(opt,
                                 device,
                                 model,
                                 dataset=opt.dataset,
                                 clean_samples_cap=opt.pwws_test_num)
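A hedged sketch of what the `WarmupMultiStepLR(optimizer, (40, 80), 0.1, 1.0 / 10.0, 2, 'linear')` scheduler constructed above typically computes; the argument interpretation (milestones, gamma, warmup_factor, warmup_iters, warmup_method) follows common implementations and is an assumption, not the original class.

from bisect import bisect_right
from torch.optim.lr_scheduler import _LRScheduler

class WarmupMultiStepLR(_LRScheduler):
    def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=0.1,
                 warmup_iters=2, warmup_method='linear', last_epoch=-1):
        self.milestones = list(milestones)
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        warmup = 1.0
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == 'constant':
                warmup = self.warmup_factor
            else:  # 'linear': ramp from warmup_factor up to 1
                alpha = self.last_epoch / self.warmup_iters
                warmup = self.warmup_factor * (1 - alpha) + alpha
        # multi-step decay: multiply by gamma at every milestone already passed
        return [base_lr * warmup *
                self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs]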
Example #18
0
def main(args):
    with open(args.config, 'r') as f:
        y = yaml.load(f, Loader=yaml.Loader)
    cfg = addict.Dict(y)
    cfg.general.config = args.config

    # misc
    device = cfg.general.device
    random.seed(cfg.general.random_state)
    os.environ['PYTHONHASHSEED'] = str(cfg.general.random_state)
    np.random.seed(cfg.general.random_state)
    torch.manual_seed(cfg.general.random_state)

    # log
    if cfg.general.expid == '':
        expid = dt.datetime.now().strftime('%Y%m%d%H%M%S')
        cfg.general.expid = expid
    else:
        expid = cfg.general.expid
    cfg.general.logdir = str(LOGDIR / expid)
    if not os.path.exists(cfg.general.logdir):
        os.makedirs(cfg.general.logdir)
    os.chmod(cfg.general.logdir, 0o777)
    logger = utils.get_logger(os.path.join(cfg.general.logdir, 'main.log'))
    logger.info(f'Logging at {cfg.general.logdir}')
    logger.info(cfg)
    shutil.copyfile(str(args.config), cfg.general.logdir + '/config.yaml')
    writer = SummaryWriter(cfg.general.logdir)

    # data
    X_train = np.load(cfg.data.X_train, allow_pickle=True)
    y_train = np.load(cfg.data.y_train, allow_pickle=True)
    logger.info('Loaded X_train, y_train')
    # CV
    kf = model_selection.__dict__[cfg.training.split](
        n_splits=cfg.training.n_splits,
        shuffle=True,
        random_state=cfg.general.random_state)  # noqa
    score_list = {'loss': [], 'score': []}
    for fold_i, (train_idx, valid_idx) in enumerate(
            kf.split(X=np.zeros(len(y_train)), y=y_train[:, 0])):
        if fold_i + 1 not in cfg.training.target_folds:
            continue
        X_train_ = X_train[train_idx]
        y_train_ = y_train[train_idx]
        X_valid_ = X_train[valid_idx]
        y_valid_ = y_train[valid_idx]
        _ratio = cfg.training.get('with_x_percent_fold_1_of_5', 0.)
        if _ratio > 0.:
            assert cfg.training.n_splits == 5 and fold_i + 1 == 1
            from sklearn.model_selection import train_test_split
            if _ratio == 0.95:
                _test_size = 0.25
            elif _ratio == 0.9:
                _test_size = 0.5
            else:
                raise NotImplementedError
            _X_train, X_valid_, _y_train, y_valid_ = train_test_split(
                X_valid_,
                y_valid_,
                test_size=_test_size,
                random_state=cfg.general.random_state)
            X_train_ = np.concatenate([X_train_, _X_train], axis=0)
            y_train_ = np.concatenate([y_train_, _y_train], axis=0)
        train_set = Dataset(X_train_, y_train_, cfg, mode='train')
        valid_set = Dataset(X_valid_, y_valid_, cfg, mode='valid')
        if fold_i == 0:
            logger.info(train_set.transform)
            logger.info(valid_set.transform)
        train_loader = DataLoader(train_set,
                                  batch_size=cfg.training.batch_size,
                                  shuffle=True,
                                  num_workers=cfg.training.n_worker,
                                  pin_memory=True)
        valid_loader = DataLoader(valid_set,
                                  batch_size=cfg.training.batch_size,
                                  shuffle=False,
                                  num_workers=cfg.training.n_worker,
                                  pin_memory=True)

        # model
        model = models.get_model(cfg=cfg)
        model = model.to(device)
        criterion = loss.get_loss_fn(cfg)
        optimizer = utils.get_optimizer(model.parameters(), config=cfg)
        scheduler = utils.get_lr_scheduler(optimizer, config=cfg)

        start_epoch = 1
        best = {'loss': 1e+9, 'score': -1.}
        is_best = {'loss': False, 'score': False}

        # resume
        if cfg.model.resume:
            if os.path.isfile(cfg.model.resume):
                checkpoint = torch.load(cfg.model.resume)
                start_epoch = checkpoint['epoch'] + 1
                best['loss'] = checkpoint['loss/best']
                best['score'] = checkpoint['score/best']
                if cfg.general.multi_gpu:
                    model.load_state_dict(
                        utils.fix_model_state_dict(checkpoint['state_dict']))
                else:
                    model.load_state_dict(checkpoint['state_dict'])
                if cfg.model.get('load_optimizer', True):
                    optimizer.load_state_dict(checkpoint['optimizer'])
                logger.info('Loaded checkpoint {} (epoch {})'.format(
                    cfg.model.resume, start_epoch - 1))
            else:
                raise IOError('No such file {}'.format(cfg.model.resume))

        if cfg.general.multi_gpu:
            model = nn.DataParallel(model)

        for epoch_i in range(start_epoch, cfg.training.epochs + 1):
            if scheduler is not None:
                if cfg.training.lr_scheduler.name == 'MultiStepLR':
                    optimizer.zero_grad()
                    optimizer.step()
                    scheduler.step()
            for param_group in optimizer.param_groups:
                current_lr = param_group['lr']
            _ohem_loss = (cfg.training.ohem_loss
                          and cfg.training.ohem_epoch < epoch_i)
            train = training(train_loader,
                             model,
                             criterion,
                             optimizer,
                             config=cfg,
                             using_ohem_loss=_ohem_loss,
                             lr=current_lr)
            valid = training(valid_loader,
                             model,
                             criterion,
                             optimizer,
                             is_training=False,
                             config=cfg,
                             lr=current_lr)

            if scheduler is not None and cfg.training.lr_scheduler.name != 'MultiStepLR':
                if cfg.training.lr_scheduler.name == 'ReduceLROnPlateau':
                    if scheduler.mode == 'min':
                        value = valid['loss']
                    elif scheduler.mode == 'max':
                        value = valid['score']
                    else:
                        raise NotImplementedError
                    scheduler.step(value)
                else:
                    scheduler.step()

            is_best['loss'] = valid['loss'] < best['loss']
            is_best['score'] = valid['score'] > best['score']
            if is_best['loss']:
                best['loss'] = valid['loss']
            if is_best['score']:
                best['score'] = valid['score']
            model_state_dict = (model.module.state_dict() if cfg.general.multi_gpu
                                else model.state_dict())  # noqa
            state_dict = {
                'epoch': epoch_i,
                'state_dict': model_state_dict,
                'optimizer': optimizer.state_dict(),
                'loss/valid': valid['loss'],
                'score/valid': valid['score'],
                'loss/best': best['loss'],
                'score/best': best['score'],
            }
            utils.save_checkpoint(
                state_dict,
                is_best,
                epoch_i,
                valid['loss'],
                valid['score'],
                Path(cfg.general.logdir) / f'fold_{fold_i}',
            )

            # tensorboard
            writer.add_scalar('Loss/Train', train['loss'], epoch_i)
            writer.add_scalar('Loss/Valid', valid['loss'], epoch_i)
            writer.add_scalar('Loss/Best', best['loss'], epoch_i)
            writer.add_scalar('Metrics/Train', train['score'], epoch_i)
            writer.add_scalar('Metrics/Valid', valid['score'], epoch_i)
            writer.add_scalar('Metrics/Best', best['score'], epoch_i)

            log = f'[{expid}] Fold {fold_i+1} Epoch {epoch_i}/{cfg.training.epochs} '
            log += f'[loss] {train["loss"]:.6f}/{valid["loss"]:.6f} '
            log += f'[score] {train["score"]:.6f}/{valid["score"]:.6f} '
            log += f'({best["score"]:.6f}) '
            log += f'lr {current_lr:.6f}'
            logger.info(log)

        score_list['loss'].append(best['loss'])
        score_list['score'].append(best['score'])
        if cfg.training.single_fold: break  # noqa

    log = f'[{expid}] '
    log += f'[loss] {cfg.training.n_splits}-fold/mean {np.mean(score_list["loss"]):.4f} '
    log += f'[score] {cfg.training.n_splits}-fold/mean {np.mean(score_list["score"]):.4f} '  # noqa
    logger.info(log)
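A minimal sketch of the `utils.fix_model_state_dict` helper assumed in the resume logic above: it strips the 'module.' prefix that nn.DataParallel adds to parameter names, so a multi-GPU checkpoint can be loaded before the model is wrapped again.

from collections import OrderedDict

def fix_model_state_dict(state_dict):
    fixed = OrderedDict()
    for key, value in state_dict.items():
        # drop the leading 'module.' added by nn.DataParallel, if present
        fixed[key[len('module.'):] if key.startswith('module.') else key] = value
    return fixed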
Example #19
0
def train(args):

    # Device Configuration #
    device = torch.device(
        f'cuda:{args.gpu_num}' if torch.cuda.is_available() else 'cpu')

    # Fix Seed for Reproducibility #
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Samples, Plots, Weights and CSV Path #
    paths = [
        args.samples_path, args.plots_path, args.weights_path, args.csv_path
    ]
    for path in paths:
        make_dirs(path)

    # Prepare Data #
    data = pd.read_csv(args.data_path)[args.column]

    # Pre-processing #
    scaler_1 = StandardScaler()
    scaler_2 = StandardScaler()
    preprocessed_data = pre_processing(data, scaler_1, scaler_2, args.delta)

    X = moving_windows(preprocessed_data, args.ts_dim)
    label = moving_windows(data.to_numpy(), args.ts_dim)

    # Prepare Networks #
    D = Discriminator(args.ts_dim).to(device)
    G = Generator(args.latent_dim, args.ts_dim,
                  args.conditional_dim).to(device)

    # Loss Function #
    if args.criterion == 'l2':
        criterion = nn.MSELoss()
    elif args.criterion == 'wgangp':
        pass
    else:
        raise NotImplementedError

    # Optimizers #
    D_optim = torch.optim.Adam(D.parameters(), lr=args.lr, betas=(0.5, 0.9))
    G_optim = torch.optim.Adam(G.parameters(), lr=args.lr, betas=(0.5, 0.9))

    D_optim_scheduler = get_lr_scheduler(D_optim, args)
    G_optim_scheduler = get_lr_scheduler(G_optim, args)

    # Lists #
    D_losses, G_losses = list(), list()

    # Train #
    print("Training Time Series GAN started with total epoch of {}.".format(
        args.num_epochs))

    for epoch in range(args.num_epochs):

        # Initialize Optimizers #
        G_optim.zero_grad()
        D_optim.zero_grad()

        if args.criterion == 'l2':
            n_critics = 1
        elif args.criterion == 'wgangp':
            n_critics = 5

        #######################
        # Train Discriminator #
        #######################

        for j in range(n_critics):
            series, start_dates = get_samples(X, label, args.batch_size)

            # Data Preparation #
            series = series.to(device)
            noise = torch.randn(args.batch_size, 1, args.latent_dim).to(device)

            # Adversarial Loss using Real Image #
            prob_real = D(series.float())

            if args.criterion == 'l2':
                real_labels = torch.ones(prob_real.size()).to(device)
                D_real_loss = criterion(prob_real, real_labels)

            elif args.criterion == 'wgangp':
                D_real_loss = -torch.mean(prob_real)

            # Adversarial Loss using Fake Image #
            fake_series = G(noise)
            fake_series = torch.cat(
                (series[:, :, :args.conditional_dim].float(),
                 fake_series.float()),
                dim=2)

            prob_fake = D(fake_series.detach())

            if args.criterion == 'l2':
                fake_labels = torch.zeros(prob_fake.size()).to(device)
                D_fake_loss = criterion(prob_fake, fake_labels)

            elif args.criterion == 'wgangp':
                D_fake_loss = torch.mean(prob_fake)
                D_gp_loss = args.lambda_gp * get_gradient_penalty(
                    D, series.float(), fake_series.float(), device)

            # Calculate Total Discriminator Loss #
            D_loss = D_fake_loss + D_real_loss

            if args.criterion == 'wgangp':
                D_loss += D_gp_loss  # D_gp_loss already includes the lambda_gp weight

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

        ###################
        # Train Generator #
        ###################

        # Adversarial Loss #
        fake_series = G(noise)
        fake_series = torch.cat(
            (series[:, :, :args.conditional_dim].float(), fake_series.float()),
            dim=2)
        prob_fake = D(fake_series)

        # Calculate Total Generator Loss #
        if args.criterion == 'l2':
            real_labels = torch.ones(prob_fake.size()).to(device)
            G_loss = criterion(prob_fake, real_labels)

        elif args.criterion == 'wgangp':
            G_loss = -torch.mean(prob_fake)

        # Back Propagation and Update #
        G_loss.backward()
        G_optim.step()

        # Add items to Lists #
        D_losses.append(D_loss.item())
        G_losses.append(G_loss.item())

        ####################
        # Print Statistics #
        ####################

        print("Epochs [{}/{}] | D Loss {:.4f} | G Loss {:.4f}".format(
            epoch + 1, args.num_epochs, np.average(D_losses),
            np.average(G_losses)))

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights and Series #
        if (epoch + 1) % args.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(
                    args.weights_path,
                    'TimeSeries_Generator_using{}_Epoch_{}.pkl'.format(
                        args.criterion.upper(), epoch + 1)))

            series, fake_series = generate_fake_samples(
                X, label, G, scaler_1, scaler_2, args, device)
            plot_sample(series, fake_series, epoch, args)
            make_csv(series, fake_series, epoch, args)

    print("Training finished.")
Example #20
0
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    for path in paths:
        make_dirs(path)

    # Prepare Data Loader #
    train_horse_loader, train_zebra_loader = get_horse2zebra_loader(
        purpose='train', batch_size=config.batch_size)
    test_horse_loader, test_zebra_loader = get_horse2zebra_loader(
        purpose='test', batch_size=config.val_batch_size)
    total_batch = min(len(train_horse_loader), len(train_zebra_loader))

    # Prepare Networks #
    D_A = Discriminator()
    D_B = Discriminator()
    G_A2B = Generator()
    G_B2A = Generator()

    networks = [D_A, D_B, G_A2B, G_B2A]

    for network in networks:
        network.to(device)

    # Loss Function #
    criterion_Adversarial = nn.MSELoss()
    criterion_Cycle = nn.L1Loss()
    criterion_Identity = nn.L1Loss()

    # Optimizers #
    D_A_optim = torch.optim.Adam(D_A.parameters(),
                                 lr=config.lr,
                                 betas=(0.5, 0.999))
    D_B_optim = torch.optim.Adam(D_B.parameters(),
                                 lr=config.lr,
                                 betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(chain(G_A2B.parameters(), G_B2A.parameters()),
                               lr=config.lr,
                               betas=(0.5, 0.999))

    D_A_optim_scheduler = get_lr_scheduler(D_A_optim)
    D_B_optim_scheduler = get_lr_scheduler(D_B_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses_A, D_losses_B, G_losses = [], [], []

    # Training #
    print("Training CycleGAN started with total epoch of {}.".format(
        config.num_epochs))
    for epoch in range(config.num_epochs):
        for i, (horse,
                zebra) in enumerate(zip(train_horse_loader,
                                        train_zebra_loader)):

            # Data Preparation #
            real_A = horse.to(device)
            real_B = zebra.to(device)

            # Initialize Optimizers #
            G_optim.zero_grad()
            D_A_optim.zero_grad()
            D_B_optim.zero_grad()

            ###################
            # Train Generator #
            ###################

            set_requires_grad([D_A, D_B], requires_grad=False)

            # Adversarial Loss #
            fake_A = G_B2A(real_B)
            prob_fake_A = D_A(fake_A)
            real_labels = torch.ones(prob_fake_A.size()).to(device)
            G_mse_loss_B2A = criterion_Adversarial(prob_fake_A, real_labels)

            fake_B = G_A2B(real_A)
            prob_fake_B = D_B(fake_B)
            real_labels = torch.ones(prob_fake_B.size()).to(device)
            G_mse_loss_A2B = criterion_Adversarial(prob_fake_B, real_labels)

            # Identity Loss #
            identity_A = G_B2A(real_A)
            G_identity_loss_A = config.lambda_identity * criterion_Identity(
                identity_A, real_A)

            identity_B = G_A2B(real_B)
            G_identity_loss_B = config.lambda_identity * criterion_Identity(
                identity_B, real_B)

            # Cycle Loss #
            reconstructed_A = G_B2A(fake_B)
            G_cycle_loss_ABA = config.lambda_cycle * criterion_Cycle(
                reconstructed_A, real_A)

            reconstructed_B = G_A2B(fake_A)
            G_cycle_loss_BAB = config.lambda_cycle * criterion_Cycle(
                reconstructed_B, real_B)

            # Calculate Total Generator Loss #
            G_loss = G_mse_loss_B2A + G_mse_loss_A2B + G_identity_loss_A + G_identity_loss_B + G_cycle_loss_ABA + G_cycle_loss_BAB

            # Back Propagation and Update #
            G_loss.backward(retain_graph=True)
            G_optim.step()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad([D_A, D_B], requires_grad=True)

            ## Train Discriminator A ##
            # Real Loss #
            prob_real_A = D_A(real_A)
            real_labels = torch.ones(prob_real_A.size()).to(device)
            D_real_loss_A = criterion_Adversarial(prob_real_A, real_labels)

            # Fake Loss #
            prob_fake_A = D_A(fake_A.detach())
            fake_labels = torch.zeros(prob_fake_A.size()).to(device)
            D_fake_loss_A = criterion_Adversarial(prob_fake_A, fake_labels)

            # Calculate Total Discriminator A Loss #
            D_loss_A = config.lambda_identity * (D_real_loss_A +
                                                 D_fake_loss_A).mean()

            # Back propagation and Update #
            D_loss_A.backward(retain_graph=True)
            D_A_optim.step()

            ## Train Discriminator B ##
            # Real Loss #
            prob_real_B = D_B(real_B)
            real_labels = torch.ones(prob_real_B.size()).to(device)
            loss_real_B = criterion_Adversarial(prob_real_B, real_labels)

            # Fake Loss #
            prob_fake_B = D_B(fake_B.detach())
            fake_labels = torch.zeros(prob_fake_B.size()).to(device)
            loss_fake_B = criterion_Adversarial(prob_fake_B, fake_labels)

            # Calculate Total Discriminator B Loss #
            D_loss_B = config.lambda_identity * (loss_real_B +
                                                 loss_fake_B).mean()

            # Back propagation and Update #
            D_loss_B.backward(retain_graph=True)
            D_B_optim.step()

            # Add items to Lists #
            D_losses_A.append(D_loss_A.item())
            D_losses_B.append(D_loss_B.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "CycleGAN | Epoch [{}/{}] | Iterations [{}/{}] | D_A Loss {:.4f} | D_B Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses_A), np.average(D_losses_B),
                            np.average(G_losses)))

                # Save Sample Images #
                sample_images(test_horse_loader, test_zebra_loader, G_A2B,
                              G_B2A, epoch, config.samples_path)

        # Adjust Learning Rate #
        D_A_optim_scheduler.step()
        D_B_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G_A2B.state_dict(),
                os.path.join(
                    config.weights_path,
                    'CycleGAN_Generator_A2B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_B2A.state_dict(),
                os.path.join(
                    config.weights_path,
                    'CycleGAN_Generator_B2A_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train("CycleGAN", config.samples_path)

    # Plot Losses #
    plot_losses(D_losses_A, D_losses_B, G_losses, config.num_epochs,
                config.plots_path)

    print("Training finished.")
Example #21
0
def main(args):

    # Fix Seed #
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Weights and Plots Path #
    paths = [args.weights_path, args.plots_path, args.numpy_path]
    for path in paths:
        make_dirs(path)

    # Prepare Data #
    data = load_data(args.which_data)[[args.feature]]
    data = data.copy()

    # Plot Time-Series Data #
    if args.plot_full:
        plot_full(args.plots_path, data, args.feature)

    scaler = MinMaxScaler()
    data[args.feature] = scaler.fit_transform(data)

    # Split the Dataset #
    copied_data = data.copy().values

    if args.multi_step:
        X, y = split_sequence_multi_step(copied_data, args.seq_length,
                                         args.output_size)
        step = 'MultiStep'
    else:
        X, y = split_sequence_uni_step(copied_data, args.seq_length)
        step = 'SingleStep'

    train_loader, val_loader, test_loader = data_loader(
        X, y, args.train_split, args.test_split, args.batch_size)

    # Lists #
    train_losses, val_losses = list(), list()
    val_maes, val_mses, val_rmses, val_mapes, val_mpes, val_r2s = \
        list(), list(), list(), list(), list(), list()
    test_maes, test_mses, test_rmses, test_mapes, test_mpes, test_r2s = \
        list(), list(), list(), list(), list(), list()
    pred_tests, labels = list(), list()

    # Constants #
    best_val_loss = 100
    best_val_improv = 0

    # Prepare Network #
    if args.model == 'dnn':
        model = DNN(args.seq_length, args.hidden_size,
                    args.output_size).to(device)
    elif args.model == 'cnn':
        model = CNN(args.seq_length, args.batch_size,
                    args.output_size).to(device)
    elif args.model == 'rnn':
        model = RNN(args.input_size, args.hidden_size, args.num_layers,
                    args.output_size).to(device)
    elif args.model == 'lstm':
        model = LSTM(args.input_size, args.hidden_size, args.num_layers,
                     args.output_size, args.bidirectional).to(device)
    elif args.model == 'gru':
        model = GRU(args.input_size, args.hidden_size, args.num_layers,
                    args.output_size).to(device)
    elif args.model == 'attentional':
        model = AttentionalLSTM(args.input_size, args.qkv, args.hidden_size,
                                args.num_layers, args.output_size,
                                args.bidirectional).to(device)
    else:
        raise NotImplementedError

    # Loss Function #
    criterion = torch.nn.MSELoss()

    # Optimizer #
    optim = torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             betas=(0.5, 0.999))
    optim_scheduler = get_lr_scheduler(args.lr_scheduler, optim)

    # Train and Validation #
    if args.mode == 'train':

        # Train #
        print("Training {} using {} started with total epoch of {}.".format(
            model.__class__.__name__, step, args.num_epochs))

        for epoch in range(args.num_epochs):
            for i, (data, label) in enumerate(train_loader):

                # Prepare Data #
                data = data.to(device, dtype=torch.float32)
                label = label.to(device, dtype=torch.float32)

                # Forward Data #
                pred = model(data)

                # Calculate Loss #
                train_loss = criterion(pred, label)

                # Initialize Optimizer, Back Propagation and Update #
                optim.zero_grad()
                train_loss.backward()
                optim.step()

                # Add item to Lists #
                train_losses.append(train_loss.item())

            # Print Statistics #
            if (epoch + 1) % args.print_every == 0:
                print("Epoch [{}/{}]".format(epoch + 1, args.num_epochs))
                print("Train Loss {:.4f}".format(np.average(train_losses)))

            # Learning Rate Scheduler #
            optim_scheduler.step()

            # Validation #
            with torch.no_grad():
                for i, (data, label) in enumerate(val_loader):

                    # Prepare Data #
                    data = data.to(device, dtype=torch.float32)
                    label = label.to(device, dtype=torch.float32)

                    # Forward Data #
                    pred_val = model(data)

                    # Calculate Loss #
                    val_loss = criterion(pred_val, label)

                    if args.multi_step:
                        pred_val = np.mean(pred_val.detach().cpu().numpy(),
                                           axis=1)
                        label = np.mean(label.detach().cpu().numpy(), axis=1)
                    else:
                        pred_val, label = pred_val.cpu(), label.cpu()

                    # Calculate Metrics #
                    val_mae = mean_absolute_error(label, pred_val)
                    val_mse = mean_squared_error(label, pred_val, squared=True)
                    val_rmse = mean_squared_error(label,
                                                  pred_val,
                                                  squared=False)
                    val_mpe = mean_percentage_error(label, pred_val)
                    val_mape = mean_absolute_percentage_error(label, pred_val)
                    val_r2 = r2_score(label, pred_val)

                    # Add item to Lists #
                    val_losses.append(val_loss.item())
                    val_maes.append(val_mae.item())
                    val_mses.append(val_mse.item())
                    val_rmses.append(val_rmse.item())
                    val_mpes.append(val_mpe.item())
                    val_mapes.append(val_mape.item())
                    val_r2s.append(val_r2.item())

            if (epoch + 1) % args.print_every == 0:

                # Print Statistics #
                print("Val Loss {:.4f}".format(np.average(val_losses)))
                print(" MAE : {:.4f}".format(np.average(val_maes)))
                print(" MSE : {:.4f}".format(np.average(val_mses)))
                print("RMSE : {:.4f}".format(np.average(val_rmses)))
                print(" MPE : {:.4f}".format(np.average(val_mpes)))
                print("MAPE : {:.4f}".format(np.average(val_mapes)))
                print(" R^2 : {:.4f}".format(np.average(val_r2s)))

                # Save the model only if validation loss decreased #
                curr_val_loss = np.average(val_losses)

                if curr_val_loss < best_val_loss:
                    best_val_loss = curr_val_loss
                    torch.save(
                        model.state_dict(),
                        os.path.join(
                            args.weights_path, 'BEST_{}_using_{}.pkl'.format(
                                model.__class__.__name__, step)))

                    print("Best model is saved!\n")
                    best_val_improv = 0

                else:
                    best_val_improv += 1
                    print("Best Validation has not improved for {} epochs.\n".
                          format(best_val_improv))

    elif args.mode == 'test':

        # Load the Model Weight #
        model.load_state_dict(
            torch.load(
                os.path.join(
                    args.weights_path,
                    'BEST_{}_using_{}.pkl'.format(model.__class__.__name__,
                                                  step))))

        # Test #
        with torch.no_grad():
            for i, (data, label) in enumerate(test_loader):

                # Prepare Data #
                data = data.to(device, dtype=torch.float32)
                label = label.to(device, dtype=torch.float32)

                # Forward Data #
                pred_test = model(data)

                # Convert to Original Value Range #
                pred_test = pred_test.detach().cpu().numpy()
                label = label.detach().cpu().numpy()

                pred_test = scaler.inverse_transform(pred_test)
                label = scaler.inverse_transform(label)

                if args.multi_step:
                    pred_test = np.mean(pred_test, axis=1)
                    label = np.mean(label, axis=1)

                pred_tests += pred_test.tolist()
                labels += label.tolist()

                # Calculate Loss #
                test_mae = mean_absolute_error(label, pred_test)
                test_mse = mean_squared_error(label, pred_test, squared=True)
                test_rmse = mean_squared_error(label, pred_test, squared=False)
                test_mpe = mean_percentage_error(label, pred_test)
                test_mape = mean_absolute_percentage_error(label, pred_test)
                test_r2 = r2_score(label, pred_test)

                # Add item to Lists #
                test_maes.append(test_mae.item())
                test_mses.append(test_mse.item())
                test_rmses.append(test_rmse.item())
                test_mpes.append(test_mpe.item())
                test_mapes.append(test_mape.item())
                test_r2s.append(test_r2.item())

            # Print Statistics #
            print("Test {} using {}".format(model.__class__.__name__, step))
            print(" MAE : {:.4f}".format(np.average(test_maes)))
            print(" MSE : {:.4f}".format(np.average(test_mses)))
            print("RMSE : {:.4f}".format(np.average(test_rmses)))
            print(" MPE : {:.4f}".format(np.average(test_mpes)))
            print("MAPE : {:.4f}".format(np.average(test_mapes)))
            print(" R^2 : {:.4f}".format(np.average(test_r2s)))

            # Plot Figure #
            plot_pred_test(pred_tests[:args.time_plot],
                           labels[:args.time_plot], args.plots_path,
                           args.feature, model, step)

            # Save Numpy files #
            np.save(
                os.path.join(
                    args.numpy_path,
                    '{}_using_{}_TestSet.npy'.format(model.__class__.__name__,
                                                     step)),
                np.asarray(pred_tests))
            np.save(
                os.path.join(args.numpy_path,
                             'TestSet_using_{}.npy'.format(step)),
                np.asarray(labels))

    else:
        raise NotImplementedError
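The helpers mean_percentage_error and mean_absolute_percentage_error called above are not available in every scikit-learn release; a minimal NumPy sketch of what they presumably compute (assuming y_true never contains zeros) is:

import numpy as np

def mean_percentage_error(y_true, y_pred):
    # Signed percentage error, averaged over all elements (assumed helper).
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return np.mean((y_true - y_pred) / y_true) * 100

def mean_absolute_percentage_error(y_true, y_pred):
    # Unsigned percentage error, averaged over all elements (assumed helper).
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return np.mean(np.abs((y_true - y_pred) / y_true)) * 100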
Example #22
def main():
    print("Device is ", DEVICE)
    model = MODEL_DISPATCHER[BASE_MODEL](pretrain=True)
    # model.load_state_dict(torch.load("../se_net/20200307-154303/weights/best_se_net_fold4_model_3_macro_recall=0.9435.pth"))
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        model = nn.DataParallel(model)

    model.to(DEVICE)
    print("Model loaded !!! ")

    exp_name = datetime.now().strftime("%Y%m%d-%H%M%S")
    if not os.path.exists(os.path.join("../", BASE_MODEL)):
        os.mkdir(os.path.join("../", BASE_MODEL))
    OUT_DIR = os.path.join("../", BASE_MODEL, exp_name)
    print("This Exp would be save in ", OUT_DIR)

    os.mkdir(OUT_DIR)

    os.mkdir(os.path.join(OUT_DIR, "weights"))

    os.mkdir(os.path.join(OUT_DIR, "log"))

    train_dataset = BengaliDatasetTrain(folds=TRAINING_FOLDS,
                                        img_height=IMG_HEIGHT,
                                        img_width=IMG_WIDTH,
                                        mean=MODEL_MEAN,
                                        std=MODEL_STD)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=TRAINING_BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=4,
    )
    valid_dataset = BengaliDatasetTrain(folds=VAL_FOLDS,
                                        img_height=IMG_HEIGHT,
                                        img_width=IMG_WIDTH,
                                        mean=MODEL_MEAN,
                                        std=MODEL_STD)
    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                                               batch_size=TEST_BATCH_SIZE,
                                               shuffle=False,
                                               num_workers=4,
                                               pin_memory=True)

    optimizer = get_optimizer(model, parameters.get("momentum"),
                              parameters.get("weight_decay"),
                              parameters.get("nesterov"))
    lr_scheduler = get_lr_scheduler(optimizer,
                                    parameters.get("lr_max_value"),
                                    parameters.get("lr_max_value_epoch"),
                                    num_epochs=EPOCH,
                                    epoch_length=len(train_loader))

    # Define Trainer
    trainer = create_trainer(model, optimizer, DEVICE, WEIGHT_ONE, WEIGHT_TWO,
                             WEIGHT_THR)

    # Recall for Training
    EpochMetric(compute_fn=macro_recall,
                output_transform=output_transform).attach(trainer, 'recall')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names='all')

    evaluator = create_evaluator(model, DEVICE)

    # Recall for evaluating
    EpochMetric(compute_fn=macro_recall,
                output_transform=output_transform).attach(evaluator, 'recall')

    def run_evaluator(engine):
        evaluator.run(valid_loader)

    def get_curr_lr(engine):
        lr = lr_scheduler.schedulers[0].optimizer.param_groups[0]['lr']
        log_report.report('lr', lr)

    def score_fn(engine):
        # EarlyStopping treats higher scores as better, so negate the loss
        score = -engine.state.metrics['loss']
        return score

    es_handler = EarlyStopping(patience=30,
                               score_function=score_fn,
                               trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, es_handler)

    def default_score_fn(engine):
        score = engine.state.metrics['recall']
        return score

    trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)
    best_model_handler = ModelCheckpoint(
        dirname=os.path.join(OUT_DIR, "weights"),
        filename_prefix=f"best_{BASE_MODEL}_fold{VAL_FOLDS[0]}",
        n_saved=3,
        global_step_transform=global_step_from_engine(trainer),
        score_name="macro_recall",
        score_function=default_score_fn)
    evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {
        "model": model,
    })

    trainer.add_event_handler(Events.EPOCH_COMPLETED, run_evaluator)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, get_curr_lr)

    log_report = LogReport(evaluator, os.path.join(OUT_DIR, "log"))

    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_report)
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        ModelSnapshotHandler(model,
                             filepath=os.path.join(
                                 OUT_DIR, "weights", "{}_fold{}.pth".format(
                                     BASE_MODEL, VAL_FOLDS[0]))))

    trainer.run(train_loader, max_epochs=EPOCH)

    train_history = log_report.get_dataframe()
    train_history.to_csv(os.path.join(
        OUT_DIR, "log", "{}_fold{}_log.csv".format(BASE_MODEL, VAL_FOLDS[0])),
                         index=False)

    print(train_history.head())
    print("Trainning Done !!!")
Example #23
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    paths = [make_dirs(path) for path in paths]

    # Prepare Data Loader #
    train_loader = get_facades_loader('train', config.batch_size)
    val_loader = get_facades_loader('val', config.val_batch_size)
    total_batch = len(train_loader)

    # Prepare Networks #
    D = Discriminator().to(device)
    G = Generator().to(device)

    # Criterion #
    criterion_Adversarial = nn.BCELoss()
    criterion_Pixelwise = nn.L1Loss()

    # Optimizers #
    D_optim = torch.optim.Adam(D.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(G.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses, G_losses = [], []

    # Training #
    print("Training Pix2Pix started with total epoch of {}.".format(
        config.num_epochs))

    for epoch in range(config.num_epochs):

        for i, (real_A, real_B) in enumerate(train_loader):

            # Data Preparation #
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            ###################
            # Train Generator #
            ###################

            # Prevent Discriminator Update during Generator Update #
            set_requires_grad(D, requires_grad=False)

            # Adversarial Loss #
            fake_B = G(real_A)
            prob_fake = D(fake_B, real_A)
            real_labels = torch.ones(prob_fake.size()).to(device)
            G_loss_fake = criterion_Adversarial(prob_fake, real_labels)

            # Pixel-Wise Loss #
            G_loss_pixelwise = criterion_Pixelwise(fake_B, real_B)

            # Calculate Total Generator Loss #
            G_loss = G_loss_fake + config.l1_lambda * G_loss_pixelwise

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            #######################
            # Train Discriminator #
            #######################

            # Allow Discriminator Update during Discriminator Training #
            set_requires_grad(D, requires_grad=True)

            # Adversarial Loss #
            prob_real = D(real_B, real_A)
            real_labels = torch.ones(prob_real.size()).to(device)
            D_real_loss = criterion_Adversarial(prob_real, real_labels)

            fake_B = G(real_A)
            prob_fake = D(fake_B.detach(), real_A)
            fake_labels = torch.zeros(prob_fake.size()).to(device)
            D_fake_loss = criterion_Adversarial(prob_fake, fake_labels)

            # Calculate Total Discriminator Loss #
            D_loss = torch.mean(D_real_loss + D_fake_loss)

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "Pix2Pix | Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(G_losses)))

                # Save Sample Images #
                sample_images(val_loader, G, epoch, config.samples_path)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(
                    config.weights_path,
                    'Pix2Pix_Generator_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train('Pix2Pix', config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.plots_path)

    print("Training finished.")
Example #24
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights, and Plots Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    paths = [make_dirs(path) for path in paths]

    # Prepare Data Loader #
    train_loader_selfie, train_loader_anime = get_selfie2anime_loader(
        'train', config.batch_size)
    total_batch = max(len(train_loader_selfie), len(train_loader_anime))

    test_loader_selfie, test_loader_anime = get_selfie2anime_loader(
        'test', config.val_batch_size)

    # Prepare Networks #
    D_A = Discriminator(num_layers=7)
    D_B = Discriminator(num_layers=7)
    L_A = Discriminator(num_layers=5)
    L_B = Discriminator(num_layers=5)
    G_A2B = Generator(image_size=config.crop_size,
                      num_blocks=config.num_blocks)
    G_B2A = Generator(image_size=config.crop_size,
                      num_blocks=config.num_blocks)

    networks = [D_A, D_B, L_A, L_B, G_A2B, G_B2A]

    for network in networks:
        network.to(device)

    # Loss Function #
    Adversarial_loss = nn.MSELoss()
    Cycle_loss = nn.L1Loss()
    BCE_loss = nn.BCEWithLogitsLoss()

    # Optimizers #
    D_optim = torch.optim.Adam(chain(D_A.parameters(), D_B.parameters(),
                                     L_A.parameters(), L_B.parameters()),
                               lr=config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.0001)
    G_optim = torch.optim.Adam(chain(G_A2B.parameters(), G_B2A.parameters()),
                               lr=config.lr,
                               betas=(0.5, 0.999),
                               weight_decay=0.0001)

    D_optim_scheduler = get_lr_scheduler(D_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Rho Clipper to constrain the value of rho in AdaILN and ILN #
    Rho_Clipper = RhoClipper(0, 1)

    # Lists #
    D_losses = []
    G_losses = []

    # Train #
    print("Training U-GAT-IT started with total epoch of {}.".format(
        config.num_epochs))
    for epoch in range(config.num_epochs):

        for i, (selfie, anime) in enumerate(
                zip(train_loader_selfie, train_loader_anime)):

            # Data Preparation #
            real_A = selfie.to(device)
            real_B = anime.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad([D_A, D_B, L_A, L_B], requires_grad=True)

            # Forward Data #
            fake_B, _, _ = G_A2B(real_A)
            fake_A, _, _ = G_B2A(real_B)

            G_real_A, G_real_A_cam, _ = D_A(real_A)
            L_real_A, L_real_A_cam, _ = L_A(real_A)
            G_real_B, G_real_B_cam, _ = D_B(real_B)
            L_real_B, L_real_B_cam, _ = L_B(real_B)

            G_fake_A, G_fake_A_cam, _ = D_A(fake_A)
            L_fake_A, L_fake_A_cam, _ = L_A(fake_A)
            G_fake_B, G_fake_B_cam, _ = D_B(fake_B)
            L_fake_B, L_fake_B_cam, _ = L_B(fake_B)

            # Adversarial Loss of Discriminator #
            real_labels = torch.ones(G_real_A.shape).to(device)
            D_ad_real_loss_GA = Adversarial_loss(G_real_A, real_labels)

            fake_labels = torch.zeros(G_fake_A.shape).to(device)
            D_ad_fake_loss_GA = Adversarial_loss(G_fake_A, fake_labels)

            D_ad_loss_GA = D_ad_real_loss_GA + D_ad_fake_loss_GA

            real_labels = torch.ones(G_real_A_cam.shape).to(device)
            D_ad_cam_real_loss_GA = Adversarial_loss(G_real_A_cam, real_labels)

            fake_labels = torch.zeros(G_fake_A_cam.shape).to(device)
            D_ad_cam_fake_loss_GA = Adversarial_loss(G_fake_A_cam, fake_labels)

            D_ad_cam_loss_GA = D_ad_cam_real_loss_GA + D_ad_cam_fake_loss_GA

            real_labels = torch.ones(G_real_B.shape).to(device)
            D_ad_real_loss_GB = Adversarial_loss(G_real_B, real_labels)

            fake_labels = torch.zeros(G_fake_B.shape).to(device)
            D_ad_fake_loss_GB = Adversarial_loss(G_fake_B, fake_labels)

            D_ad_loss_GB = D_ad_real_loss_GB + D_ad_fake_loss_GB

            real_labels = torch.ones(G_real_B_cam.shape).to(device)
            D_ad_cam_real_loss_GB = Adversarial_loss(G_real_B_cam, real_labels)

            fake_labels = torch.zeros(G_fake_B_cam.shape).to(device)
            D_ad_cam_fake_loss_GB = Adversarial_loss(G_fake_B_cam, fake_labels)

            D_ad_cam_loss_GB = D_ad_cam_real_loss_GB + D_ad_cam_fake_loss_GB

            # Adversarial Loss of L #
            real_labels = torch.ones(L_real_A.shape).to(device)
            D_ad_real_loss_LA = Adversarial_loss(L_real_A, real_labels)

            fake_labels = torch.zeros(L_fake_A.shape).to(device)
            D_ad_fake_loss_LA = Adversarial_loss(L_fake_A, fake_labels)

            D_ad_loss_LA = D_ad_real_loss_LA + D_ad_fake_loss_LA

            real_labels = torch.ones(L_real_A_cam.shape).to(device)
            D_ad_cam_real_loss_LA = Adversarial_loss(L_real_A_cam, real_labels)

            fake_labels = torch.zeros(L_fake_A_cam.shape).to(device)
            D_ad_cam_fake_loss_LA = Adversarial_loss(L_fake_A_cam, fake_labels)

            D_ad_cam_loss_LA = D_ad_cam_real_loss_LA + D_ad_cam_fake_loss_LA

            real_labels = torch.ones(L_real_B.shape).to(device)
            D_ad_real_loss_LB = Adversarial_loss(L_real_B, real_labels)

            fake_labels = torch.zeros(L_fake_B.shape).to(device)
            D_ad_fake_loss_LB = Adversarial_loss(L_fake_B, fake_labels)

            D_ad_loss_LB = D_ad_real_loss_LB + D_ad_fake_loss_LB

            real_labels = torch.ones(L_real_B_cam.shape).to(device)
            D_ad_cam_real_loss_LB = Adversarial_loss(L_real_B_cam, real_labels)

            fake_labels = torch.zeros(L_fake_B_cam.shape).to(device)
            D_ad_cam_fake_loss_LB = Adversarial_loss(L_fake_B_cam, fake_labels)

            D_ad_cam_loss_LB = D_ad_cam_real_loss_LB + D_ad_cam_fake_loss_LB

            # Calculate Each Discriminator Loss #
            D_loss_A = D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA
            D_loss_B = D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB

            # Calculate Total Discriminator Loss #
            D_loss = D_loss_A + D_loss_B

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            ###################
            # Train Generator #
            ###################

            set_requires_grad([D_A, D_B, L_A, L_B], requires_grad=False)

            # Forward Data #
            fake_B, fake_B_cam, _ = G_A2B(real_A)
            fake_A, fake_A_cam, _ = G_B2A(real_B)

            fake_ABA, _, _ = G_B2A(fake_B)
            fake_BAB, _, _ = G_A2B(fake_A)

            # Identity Mapping: feed each domain through the generator that outputs it #
            fake_A2A, fake_A2A_cam, _ = G_B2A(real_A)
            fake_B2B, fake_B2B_cam, _ = G_A2B(real_B)

            G_fake_A, G_fake_A_cam, _ = D_A(fake_A)
            L_fake_A, L_fake_A_cam, _ = L_A(fake_A)
            G_fake_B, G_fake_B_cam, _ = D_B(fake_B)
            L_fake_B, L_fake_B_cam, _ = L_B(fake_B)

            # Adversarial Loss of Generator #
            real_labels = torch.ones(G_fake_A.shape).to(device)
            G_adv_fake_loss_A = Adversarial_loss(G_fake_A, real_labels)

            real_labels = torch.ones(G_fake_A_cam.shape).to(device)
            G_adv_cam_fake_loss_A = Adversarial_loss(G_fake_A_cam, real_labels)

            G_adv_loss_A = G_adv_fake_loss_A + G_adv_cam_fake_loss_A

            real_labels = torch.ones(G_fake_B.shape).to(device)
            G_adv_fake_loss_B = Adversarial_loss(G_fake_B, real_labels)

            real_labels = torch.ones(G_fake_B_cam.shape).to(device)
            G_adv_cam_fake_loss_B = Adversarial_loss(G_fake_B_cam, real_labels)

            G_adv_loss_B = G_adv_fake_loss_B + G_adv_cam_fake_loss_B

            # Adversarial Loss of L #
            real_labels = torch.ones(L_fake_A.shape).to(device)
            L_adv_fake_loss_A = Adversarial_loss(L_fake_A, real_labels)

            real_labels = torch.ones(L_fake_A_cam.shape).to(device)
            L_adv_cam_fake_loss_A = Adversarial_loss(L_fake_A_cam, real_labels)

            L_adv_loss_A = L_adv_fake_loss_A + L_adv_cam_fake_loss_A

            real_labels = torch.ones(L_fake_B.shape).to(device)
            L_adv_fake_loss_B = Adversarial_loss(L_fake_B, real_labels)

            real_labels = torch.ones(L_fake_B_cam.shape).to(device)
            L_adv_cam_fake_loss_B = Adversarial_loss(L_fake_B_cam, real_labels)

            L_adv_loss_B = L_adv_fake_loss_B + L_adv_cam_fake_loss_B

            # Cycle Consistency Loss #
            G_recon_loss_A = Cycle_loss(fake_ABA, real_A)
            G_recon_loss_B = Cycle_loss(fake_BAB, real_B)

            G_identity_loss_A = Cycle_loss(fake_A2A, real_A)
            G_identity_loss_B = Cycle_loss(fake_B2B, real_B)

            G_cycle_loss_A = G_recon_loss_A + G_identity_loss_A
            G_cycle_loss_B = G_recon_loss_B + G_identity_loss_B

            # CAM Loss #
            real_labels = torch.ones(fake_A_cam.shape).to(device)
            G_cam_real_loss_A = BCE_loss(fake_A_cam, real_labels)

            fake_labels = torch.zeros(fake_A2A_cam.shape).to(device)
            G_cam_fake_loss_A = BCE_loss(fake_A2A_cam, fake_labels)

            G_cam_loss_A = G_cam_real_loss_A + G_cam_fake_loss_A

            real_labels = torch.ones(fake_B_cam.shape).to(device)
            G_cam_real_loss_B = BCE_loss(fake_B_cam, real_labels)

            fake_labels = torch.zeros(fake_B2B_cam.shape).to(device)
            G_cam_fake_loss_B = BCE_loss(fake_B2B_cam, fake_labels)

            G_cam_loss_B = G_cam_real_loss_B + G_cam_fake_loss_B

            # Calculate Each Generator Loss #
            G_loss_A = G_adv_loss_A + L_adv_loss_A + config.lambda_cycle * G_cycle_loss_A + config.lambda_cam * G_cam_loss_A
            G_loss_B = G_adv_loss_B + L_adv_loss_B + config.lambda_cycle * G_cycle_loss_B + config.lambda_cam * G_cam_loss_B

            # Calculate Total Generator Loss #
            G_loss = G_loss_A + G_loss_B

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Apply Rho Clipper to Generators #
            G_A2B.apply(Rho_Clipper)
            G_B2A.apply(Rho_Clipper)

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "U-GAT-IT | Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(G_losses)))

                # Save Sample Images #
                save_samples(test_loader_selfie, G_A2B, epoch,
                             config.samples_path)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                D_A.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_D_A_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                D_B.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_D_B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                L_A.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_L_A_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                L_B.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_L_B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_A2B.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_G_A2B_Epoch_{}.pkl'.format(epoch + 1)))
            torch.save(
                G_B2A.state_dict(),
                os.path.join(config.weights_path,
                             'U-GAT-IT_G_B2A_Epoch_{}.pkl'.format(epoch + 1)))

    # Plot Losses #
    plot_losses(D_losses, G_losses, config.num_epochs, config.plots_path)

    # Make a GIF file #
    make_gifs_train('U-GAT-IT', config.samples_path)

    print("Training finished.")
Example #25
def train_srgans(train_loader, val_loader, generator, discriminator, device,
                 args):

    # Loss Function #
    criterion_Perceptual = PerceptualLoss(args.model).to(device)

    # For SRGAN #
    criterion_MSE = nn.MSELoss()
    criterion_TV = TVLoss()

    # For ESRGAN #
    criterion_BCE = nn.BCEWithLogitsLoss()
    criterion_Content = nn.L1Loss()

    # Optimizers #
    D_optim = torch.optim.Adam(discriminator.parameters(),
                               lr=args.lr,
                               betas=(0.9, 0.999))
    G_optim = torch.optim.Adam(generator.parameters(),
                               lr=args.lr,
                               betas=(0.9, 0.999))

    D_optim_scheduler = get_lr_scheduler(D_optim, args)
    G_optim_scheduler = get_lr_scheduler(G_optim, args)

    # Lists #
    D_losses, G_losses = list(), list()

    # Train #
    print("Training {} started with total epoch of {}.".format(
        str(args.model).upper(), args.num_epochs))

    for epoch in range(args.num_epochs):
        for i, (high, low) in enumerate(train_loader):

            discriminator.train()
            if args.model == "srgan":
                generator.train()

            # Data Preparation #
            high = high.to(device)
            low = low.to(device)

            # Initialize Optimizers #
            D_optim.zero_grad()
            G_optim.zero_grad()

            #######################
            # Train Discriminator #
            #######################

            set_requires_grad(discriminator, requires_grad=True)

            # Generate Fake HR Images #
            fake_high = generator(low)

            if args.model == 'srgan':

                # Forward Data #
                prob_real = discriminator(high)
                prob_fake = discriminator(fake_high.detach())

                # Calculate Total Discriminator Loss #
                D_loss = 1 - prob_real.mean() + prob_fake.mean()

            elif args.model == 'esrgan':

                # Forward Data #
                prob_real = discriminator(high)
                prob_fake = discriminator(fake_high.detach())

                # Relativistic Discriminator #
                diff_r2f = prob_real - prob_fake.mean()
                diff_f2r = prob_fake - prob_real.mean()

                # Labels #
                real_labels = torch.ones(diff_r2f.size()).to(device)
                fake_labels = torch.zeros(diff_f2r.size()).to(device)

                # Adversarial Loss #
                D_loss_real = criterion_BCE(diff_r2f, real_labels)
                D_loss_fake = criterion_BCE(diff_f2r, fake_labels)

                # Calculate Total Discriminator Loss #
                D_loss = (D_loss_real + D_loss_fake).mean()

            # Back Propagation and Update #
            D_loss.backward()
            D_optim.step()

            ###################
            # Train Generator #
            ###################

            set_requires_grad(discriminator, requires_grad=False)

            if args.model == 'srgan':

                # Adversarial Loss #
                prob_fake = discriminator(fake_high).mean()
                G_loss_adversarial = torch.mean(1 - prob_fake)
                G_loss_mse = criterion_MSE(fake_high, high)

                # Perceptual Loss #
                lambda_perceptual = 6e-3
                G_loss_perceptual = criterion_Perceptual(fake_high, high)

                # Total Variation Loss #
                G_loss_tv = criterion_TV(fake_high)

                # Calculate Total Generator Loss #
                G_loss = args.lambda_adversarial * G_loss_adversarial + G_loss_mse + lambda_perceptual * G_loss_perceptual + args.lambda_tv * G_loss_tv

            elif args.model == 'esrgan':

                # Forward Data #
                prob_real = discriminator(high)
                prob_fake = discriminator(fake_high)

                # Relativistic Discriminator #
                diff_r2f = prob_real - prob_fake.mean()
                diff_f2r = prob_fake - prob_real.mean()

                # Labels #
                real_labels = torch.ones(diff_r2f.size()).to(device)
                fake_labels = torch.zeros(diff_f2r.size()).to(device)

                # Adversarial Loss #
                G_loss_bce_real = criterion_BCE(diff_f2r, real_labels)
                G_loss_bce_fake = criterion_BCE(diff_r2f, fake_labels)

                G_loss_bce = (G_loss_bce_real + G_loss_bce_fake).mean()

                # Perceptual Loss #
                lambda_perceptual = 1e-2
                G_loss_perceptual = criterion_Perceptual(fake_high, high)

                # Content Loss #
                G_loss_content = criterion_Content(fake_high, high)

                # Calculate Total Generator Loss #
                G_loss = args.lambda_bce * G_loss_bce + lambda_perceptual * G_loss_perceptual + args.lambda_content * G_loss_content

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % args.print_every == 0:
                print(
                    "{} | Epoch [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | G Loss {:.4f}"
                    .format(
                        str(args.model).upper(), epoch + 1, args.num_epochs,
                        i + 1, len(train_loader), np.average(D_losses),
                        np.average(G_losses)))

                # Save Sample Images #
                sample_images(val_loader, args.batch_size, args.scale_factor,
                              generator, epoch, args.samples_path, device)

        # Adjust Learning Rate #
        D_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Model Weights and Inference #
        if (epoch + 1) % args.save_every == 0:
            torch.save(
                generator.state_dict(),
                os.path.join(
                    args.weights_path,
                    '{}_Epoch_{}.pkl'.format(generator.__class__.__name__,
                                             epoch + 1)))
            inference(val_loader, generator, args.upscale_factor, epoch,
                      args.inference_path, device)
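TVLoss above is a total-variation regulariser on the generated image; a minimal sketch, assuming NCHW tensors and an unweighted formulation:

import torch
import torch.nn as nn

class TVLoss(nn.Module):
    # Total variation loss: penalises differences between neighbouring pixels
    # (minimal sketch of the assumed helper; weighting conventions vary).
    def forward(self, x):
        diff_h = torch.mean(torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]))
        diff_w = torch.mean(torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]))
        return diff_h + diff_w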
Example #26
def train_srcnns(train_loader, val_loader, model, device, args):

    # Loss Function #
    criterion = nn.L1Loss()

    # Optimizers #
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 betas=(0.5, 0.999))
    optimizer_scheduler = get_lr_scheduler(optimizer=optimizer, args=args)

    # Lists #
    losses = list()

    # Train #
    print("Training {} started with total epoch of {}.".format(
        str(args.model).upper(), args.num_epochs))

    for epoch in range(args.num_epochs):
        for i, (high, low) in enumerate(train_loader):

            # Data Preparation #
            high = high.to(device)
            low = low.to(device)

            # Forward Data #
            generated = model(low)

            # Calculate Loss #
            loss = criterion(generated, high)

            # Initialize Optimizer #
            optimizer.zero_grad()

            # Back Propagation and Update #
            loss.backward()
            optimizer.step()

            # Add items to Lists #
            losses.append(loss.item())

            # Print Statistics #
            if (i + 1) % args.print_every == 0:
                print("{} | Epoch [{}/{}] | Iterations [{}/{}] | Loss {:.4f}".
                      format(
                          str(args.model).upper(), epoch + 1, args.num_epochs,
                          i + 1, len(train_loader), np.average(losses)))

                # Save Sample Images #
                sample_images(val_loader, args.batch_size, args.upscale_factor,
                              model, epoch, args.samples_path, device)

        # Adjust Learning Rate #
        optimizer_scheduler.step()

        # Save Model Weights and Inference #
        if (epoch + 1) % args.save_every == 0:
            torch.save(
                model.state_dict(),
                os.path.join(
                    args.weights_path,
                    '{}_Epoch_{}.pkl'.format(model.__class__.__name__,
                                             epoch + 1)))
            inference(val_loader, model, args.upscale_factor, epoch,
                      args.inference_path, device)
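get_lr_scheduler(optimizer=..., args=...) is project-specific; one plausible sketch is a step decay, assuming hypothetical args.lr_decay_every and args.lr_decay_rate fields:

import torch

def get_lr_scheduler(optimizer, args):
    # Hypothetical implementation: step decay driven by two assumed args fields.
    return torch.optim.lr_scheduler.StepLR(optimizer,
                                           step_size=args.lr_decay_every,
                                           gamma=args.lr_decay_rate)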
Example #27
def train(opt, train_iter, test_iter, verbose=True):
    global_start = time.time()
    logger = utils.getLogger()
    model = models.setup(opt)
    if torch.cuda.is_available():
        model.cuda()
    params = [param for param in model.parameters() if param.requires_grad]

    model_info = ";".join([
        str(k) + ":" + str(v) for k, v in opt.__dict__.items()
        if type(v) in (str, int, float, list, bool)
    ])
    logger.info("# parameters:" + str(sum(param.numel() for param in params)))
    logger.info(model_info)

    model.train()
    optimizer = utils.getOptimizer(params,
                                   name=opt.optimizer,
                                   lr=opt.learning_rate,
                                   scheduler=utils.get_lr_scheduler(
                                       opt.lr_scheduler))

    loss_fun = F.cross_entropy

    filename = None
    precisions = []
    for i in range(opt.max_epoch):
        for epoch, batch in enumerate(train_iter):
            optimizer.zero_grad()
            start = time.time()

            text = batch.text[0] if opt.from_torchtext else batch.text
            predicted = model(text)

            loss = loss_fun(predicted, batch.label)

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()

            if verbose:
                logger.info(
                    "%d iteration %d epoch with loss : %.5f in %.4f seconds" %
                    (i, epoch, loss.item(), time.time() - start))

        precision = utils.evaluation(model, test_iter, opt.from_torchtext)
        if verbose:
            logger.info("%d iteration with precision %.4f" % (i, precision))
        if len(precisions) == 0 or precision > max(precisions):
            if filename:
                os.remove(filename)
            filename = model.save(metric=precision)
        precisions.append(precision)


    df = pd.read_csv(performance_log_file, index_col=0, sep="\t")
    df.loc[model_info, opt.dataset] = max(precisions)
    df.to_csv(performance_log_file, sep="\t")
    logger.info(model_info + " with time :" + str(time.time() - global_start) +
                " ->" + str(max(precisions)))
    print(model_info + " with time :" + str(time.time() - global_start) +
          " ->" + str(max(precisions)))
Example #28
def train():

    # Fix Seed for Reproducibility #
    torch.manual_seed(9)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(9)

    # Samples, Weights and Results Path #
    paths = [config.samples_path, config.weights_path, config.plots_path]
    paths = [make_dirs(path) for path in paths]

    # Prepare Data Loader #
    train_loader = get_edges2handbags_loader(purpose='train',
                                             batch_size=config.batch_size)
    val_loader = get_edges2handbags_loader(purpose='val',
                                           batch_size=config.batch_size)
    total_batch = len(train_loader)

    # Prepare Networks #
    D_cVAE = Discriminator()
    D_cLR = Discriminator()
    E = Encoder(config.z_dim)
    G = Generator(config.z_dim)

    networks = [D_cVAE, D_cLR, E, G]
    for network in networks:
        network.to(device)

    # Loss Function #
    criterion_Recon = nn.L1Loss()
    criterion_Adversarial = nn.MSELoss()

    # Optimizers #
    D_cVAE_optim = torch.optim.Adam(D_cVAE.parameters(),
                                    lr=config.lr,
                                    betas=(0.5, 0.999))
    D_cLR_optim = torch.optim.Adam(D_cLR.parameters(),
                                   lr=config.lr,
                                   betas=(0.5, 0.999))
    E_optim = torch.optim.Adam(E.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))
    G_optim = torch.optim.Adam(G.parameters(),
                               lr=config.lr,
                               betas=(0.5, 0.999))

    D_cVAE_optim_scheduler = get_lr_scheduler(D_cVAE_optim)
    D_cLR_optim_scheduler = get_lr_scheduler(D_cLR_optim)
    E_optim_scheduler = get_lr_scheduler(E_optim)
    G_optim_scheduler = get_lr_scheduler(G_optim)

    # Lists #
    D_losses, E_G_losses, G_losses = [], [], []

    # Fixed Noise #
    fixed_noise = torch.randn(config.test_size, config.num_images,
                              config.z_dim).to(device)

    # Training #
    print("Training BicycleGAN started total epoch of {}.".format(
        config.num_epochs))
    for epoch in range(config.num_epochs):
        for i, (sketch, target) in enumerate(train_loader):

            # Data Preparation #
            sketch = sketch.to(device)
            target = target.to(device)

            # Separate Data for D_cVAE-GAN and D_cLR-GAN #
            cVAE_data = {
                'sketch': sketch[0].unsqueeze(dim=0),
                'target': target[0].unsqueeze(dim=0)
            }
            cLR_data = {
                'sketch': sketch[1].unsqueeze(dim=0),
                'target': target[1].unsqueeze(dim=0)
            }

            # Initialize Optimizers #
            D_cVAE_optim.zero_grad()
            D_cLR_optim.zero_grad()
            E_optim.zero_grad()
            G_optim.zero_grad()

            # Train Discriminators #
            set_requires_grad([D_cVAE, D_cLR], requires_grad=True)

            ################################
            # Train Discriminator cVAE-GAN #
            ################################

            # Initialize Optimizers #
            D_cVAE_optim.zero_grad()
            D_cLR_optim.zero_grad()
            E_optim.zero_grad()
            G_optim.zero_grad()

            # Encode Latent Vector #
            mean, std = E(cVAE_data['target'])
            random_z = torch.randn(1, config.z_dim).to(device)
            encoded_z = mean + (random_z * std)

            # Generate Fake Image #
            fake_image_cVAE = G(cVAE_data['sketch'], encoded_z)

            # Forward to Discriminator cVAE-GAN #
            prob_real_D_cVAE_1, prob_real_D_cVAE_2 = D_cVAE(
                cVAE_data['target'])
            prob_fake_D_cVAE_1, prob_fake_D_cVAE_2 = D_cVAE(
                fake_image_cVAE.detach())

            # Adversarial Loss using cVAE_1 #
            real_labels = torch.ones(prob_real_D_cVAE_1.size()).to(device)
            D_cVAE_1_real_loss = criterion_Adversarial(prob_real_D_cVAE_1,
                                                       real_labels)

            fake_labels = torch.zeros(prob_fake_D_cVAE_1.size()).to(device)
            D_cVAE_1_fake_loss = criterion_Adversarial(prob_fake_D_cVAE_1,
                                                       fake_labels)

            D_cVAE_1_loss = D_cVAE_1_real_loss + D_cVAE_1_fake_loss

            # Adversarial Loss using cVAE_2 #
            real_labels = torch.ones(prob_real_D_cVAE_2.size()).to(device)
            D_cVAE_2_real_loss = criterion_Adversarial(prob_real_D_cVAE_2,
                                                       real_labels)

            fake_labels = torch.zeros(prob_fake_D_cVAE_2.size()).to(device)
            D_cVAE_2_fake_loss = criterion_Adversarial(prob_fake_D_cVAE_2,
                                                       fake_labels)

            D_cVAE_2_loss = D_cVAE_2_real_loss + D_cVAE_2_fake_loss

            ###########################
            # Train Discriminator cLR #
            ###########################

            # Initialize Optimizers #
            D_cVAE_optim.zero_grad()
            D_cLR_optim.zero_grad()
            E_optim.zero_grad()
            G_optim.zero_grad()

            # Generate Fake Image using Random Latent Vector #
            random_z = torch.randn(1, config.z_dim).to(device)
            fake_image_cLR = G(cLR_data['sketch'], random_z)

            # Forward to Discriminator cLR-GAN #
            prob_real_D_cLR_1, prob_real_D_cLR_2 = D_cLR(cLR_data['target'])
            prob_fake_D_cLR_1, prob_fake_D_cLR_2 = D_cLR(
                fake_image_cLR.detach())

            # Adversarial Loss using cLR-1 #
            real_labels = torch.ones(prob_real_D_cLR_1.size()).to(device)
            D_cLR_1_real_loss = criterion_Adversarial(prob_real_D_cLR_1,
                                                      real_labels)

            fake_labels = torch.zeros(prob_fake_D_cLR_1.size()).to(device)
            D_cLR_1_fake_loss = criterion_Adversarial(prob_fake_D_cLR_1,
                                                      fake_labels)

            D_cLR_1_loss = D_cLR_1_real_loss + D_cLR_1_fake_loss

            # Adversarial Loss using cLR-2 #
            real_labels = torch.ones(prob_real_D_cLR_2.size()).to(device)
            D_cLR_2_real_loss = criterion_Adversarial(prob_real_D_cLR_2,
                                                      real_labels)

            fake_labels = torch.zeros(prob_fake_D_cLR_2.size()).to(device)
            D_cLR_2_fake_loss = criterion_Adversarial(prob_fake_D_cLR_2,
                                                      fake_labels)

            D_cLR_2_loss = D_cLR_2_real_loss + D_cLR_2_fake_loss

            # Calculate Total Discriminator Loss #
            D_loss = D_cVAE_1_loss + D_cVAE_2_loss + D_cLR_1_loss + D_cLR_2_loss

            # Back Propagation and Update #
            D_loss.backward()
            D_cVAE_optim.step()
            D_cLR_optim.step()

            set_requires_grad([D_cVAE, D_cLR], requires_grad=False)

            ###############################
            # Train Encoder and Generator #
            ###############################

            # Initialize Optimizers #
            D_cVAE_optim.zero_grad()
            D_cLR_optim.zero_grad()
            E_optim.zero_grad()
            G_optim.zero_grad()

            # Encode Latent Vector #
            mean, std = E(cVAE_data['target'])
            random_z = torch.randn(1, config.z_dim).to(device)
            encoded_z = mean + (random_z * std)

            # Generate Fake Image #
            fake_image_cVAE = G(cVAE_data['sketch'], encoded_z)
            prob_fake_D_cVAE_1, prob_fake_D_cVAE_2 = D_cVAE(fake_image_cVAE)

            # Adversarial Loss using cVAE #
            real_labels = torch.ones(prob_fake_D_cVAE_1.size()).to(device)
            E_G_adv_cVAE_1_loss = criterion_Adversarial(
                prob_fake_D_cVAE_1, real_labels)

            real_labels = torch.ones(prob_fake_D_cVAE_2.size()).to(device)
            E_G_adv_cVAE_2_loss = criterion_Adversarial(
                prob_fake_D_cVAE_2, real_labels)

            E_G_adv_cVAE_loss = E_G_adv_cVAE_1_loss + E_G_adv_cVAE_2_loss

            # Generate Fake Image using Random Latent Vector #
            random_z = torch.randn(1, config.z_dim).to(device)
            fake_image_cLR = G(cLR_data['sketch'], random_z)
            prob_fake_D_cLR_1, prob_fake_D_cLR_2 = D_cLR(fake_image_cLR)

            # Adversarial Loss of cLR #
            real_labels = torch.ones(prob_fake_D_cLR_1.size()).to(device)
            E_G_adv_cLR_1_loss = criterion_Adversarial(prob_fake_D_cLR_1,
                                                       real_labels)

            real_labels = torch.ones(prob_fake_D_cLR_2.size()).to(device)
            E_G_adv_cLR_2_loss = criterion_Adversarial(prob_fake_D_cLR_2,
                                                       real_labels)

            E_G_adv_cLR_loss = E_G_adv_cLR_1_loss + E_G_adv_cLR_2_loss

            # KL Divergence with N ~ (0, 1) #
            E_KL_div_loss = config.lambda_KL * torch.sum(
                0.5 * (mean**2 + std - 2 * torch.log(std) - 1))

            # Reconstruction Loss #
            E_G_recon_loss = config.lambda_Image * criterion_Recon(
                fake_image_cVAE, cVAE_data['target'])

            # Calculate Total Encoder and Generator Loss #
            E_G_loss = E_G_adv_cVAE_loss + E_G_adv_cLR_loss + E_KL_div_loss + E_G_recon_loss

            # Back Propagation and Update #
            E_G_loss.backward()
            E_optim.step()
            G_optim.step()

            ########################
            # Train Generator Only #
            ########################

            # Initialize Optimizers #
            D_cVAE_optim.zero_grad()
            D_cLR_optim.zero_grad()
            E_optim.zero_grad()
            G_optim.zero_grad()

            # Generate Fake Image using Random Latent Vector #
            random_z = torch.randn(1, config.z_dim).to(device)
            fake_image_cLR = G(cLR_data['sketch'], random_z)
            mean, std = E(fake_image_cLR)

            # Reconstruction Loss #
            G_recon_loss = criterion_Recon(mean, random_z)

            # Calculate Total Generator Loss #
            G_loss = config.lambda_Z * G_recon_loss

            # Back Propagation and Update #
            G_loss.backward()
            G_optim.step()

            # Add items to Lists #
            D_losses.append(D_loss.item())
            E_G_losses.append(E_G_loss.item())
            G_losses.append(G_loss.item())

            ####################
            # Print Statistics #
            ####################

            if (i + 1) % config.print_every == 0:
                print(
                    "BicycleGAN | Epochs [{}/{}] | Iterations [{}/{}] | D Loss {:.4f} | E_G Loss {:.4f} | G Loss {:.4f}"
                    .format(epoch + 1, config.num_epochs, i + 1, total_batch,
                            np.average(D_losses), np.average(E_G_losses),
                            np.average(G_losses)))

                # Save Sample Images #
                sample_images(val_loader, G, fixed_noise, epoch,
                              config.num_images, config.samples_path)

        # Adjust Learning Rate #
        D_cVAE_optim_scheduler.step()
        D_cLR_optim_scheduler.step()
        E_optim_scheduler.step()
        G_optim_scheduler.step()

        # Save Models #
        if (epoch + 1) % config.save_every == 0:
            torch.save(
                G.state_dict(),
                os.path.join(
                    config.weights_path,
                    'BicycleGAN_Generator_Epoch_{}.pkl'.format(epoch + 1)))

    # Make a GIF file #
    make_gifs_train("BicycleGAN", config.samples_path)

    # Plot Losses #
    plot_losses(D_losses, E_G_losses, G_losses, config.num_epochs,
                config.plots_path)

    print("Training finished.")
Example #29
def main(args):
    print(hash(str(args)))
    if not os.path.exists(args.logs_path):
        os.makedirs(args.logs_path)
    logger = get_logger(f"{args.logs_path}/l{hash(str(args))}.log")
    logger.info(f"args: {str(args)}")
    logger.info(f"hash is: {hash(str(args))}")
    args.model_dir = f"{args.model_dir}/m{hash(str(args))}"
    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)
    logger.info(f"checkpoint's dir is: {args.model_dir}")
    seed = hash(str(args)) % 1000_000
    T.manual_seed(seed)
    if use_gpu:
        T.cuda.manual_seed(seed)
    logger.info(f"random seed is: {seed}")

    logger.info("loading data...")
    with open("data/mscoco/dict.pckl", "rb") as f:
        d = pickle.load(f)
        word_to_idx = d["word_to_idx"]
        idx_to_word = d["idx_to_word"]
        bound_idx = word_to_idx["<S>"]
    train_features = np.load('data/mscoco/train_features.npy')
    valid_features = np.load('data/mscoco/valid_features.npy')
    logger.info('loaded...')

    args.vocab_size = len(word_to_idx)
    args.image_dim = valid_features.shape[1]

    train_dataset = ImageDataset(train_features)
    valid_dataset = ImageDataset(valid_features,
                                 mean=train_dataset.mean,
                                 std=train_dataset.std)

    # train_data = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, pin_memory=True)
    # valid_data = DataLoader(valid_dataset, batch_size=args.batch_size, num_workers=8, pin_memory=True)

    train_data = DataLoader(train_dataset,
                            num_workers=8,
                            pin_memory=True,
                            batch_sampler=BatchSampler(
                                ImagesSampler(train_dataset, K, shuffle=True),
                                batch_size=args.batch_size,
                                drop_last=True))

    valid_data = DataLoader(valid_dataset,
                            num_workers=8,
                            pin_memory=True,
                            batch_sampler=BatchSampler(
                                ImagesSampler(valid_dataset, K, shuffle=False),
                                batch_size=args.batch_size,
                                drop_last=True))

    model = Game(vocab_size=args.vocab_size,
                 image_dim=args.image_dim,
                 s_embd_dim=args.sender_embd_dim,
                 s_hid_dim=args.sender_hid_dim,
                 r_embd_dim=args.receiver_embd_dim,
                 r_hid_dim=args.receiver_hid_dim,
                 bound_idx=bound_idx,
                 max_steps=args.max_sentence_len,
                 tau=args.tau,
                 straight_through=args.straight_through)
    if use_gpu:
        model = model.cuda(args.gpu_id)

    optimizer = T.optim.Adam(model.parameters(), lr=args.lr)
    lr_scheduler = get_lr_scheduler(logger, optimizer)
    es = EarlyStopping(mode="max",
                       patience=30,
                       threshold=0.005,
                       threshold_mode="rel")

    validate(valid_data, model, 0, args, logger)
    for epoch in range(args.max_epoch):
        train(train_data, model, optimizer, epoch, args, logger)
        val_accuracy = validate(valid_data, model, epoch, args, logger)
        if val_accuracy > train.prev_accuracy:
            logger.info("saving model...")
            T.save({
                "epoch": epoch,
                "state_dict": model.state_dict()
            }, f"{args.model_dir}/{epoch}.mdl")
            train.prev_accuracy = val_accuracy
        logger.info(90 * '=')
        lr_scheduler.step(val_accuracy)
        es.step(val_accuracy)
        if es.is_converged:
            logger.info("done")
            break
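Because lr_scheduler.step(val_accuracy) is called with a metric, get_lr_scheduler(logger, optimizer) presumably wraps ReduceLROnPlateau in "max" mode; a minimal sketch with an assumed default patience:

import torch

def get_lr_scheduler(logger, optimizer, patience=10):
    # Hypothetical helper: reduce the learning rate when the monitored accuracy
    # stops improving for `patience` epochs.
    logger.info("creating ReduceLROnPlateau scheduler (patience=%d)" % patience)
    return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      mode="max",
                                                      patience=patience)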