Training loops
Quick Keras-like training using skorch
import numpy
import torch
import skorch
from skorch import NeuralNetClassifier
# ---
# Data loading
# X, y = ...
# Model creation. The model is a regular PyTorch module.
# model = ...
# ---
# skorch expects the full dataset as plain numpy arrays, so move tensors to CPU first
X = X.cpu().numpy().astype(numpy.float32)
y = y.cpu().squeeze().numpy().astype(numpy.int64)
# Train a classifier
net = NeuralNetClassifier(
    model,
    max_epochs=300,
    lr=0.001,
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.Adam,
    callbacks=[
        ('early_s', skorch.callbacks.EarlyStopping(monitor='valid_loss', patience=20, threshold=0.0001)),
    ],
    device='cuda',
    # Options below are passed to the module's constructor, with `module__` prepended to the option names
    module__n_classes=n_classes,
    module__dropout_probability=dropout_rate,
)
net.fit(X, y)
# Save trained model
filename = 'networks_params.pt'
torch.save(net.module_.state_dict(), filename)
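Once fitted, the wrapper exposes the usual scikit-learn interface, and the raw state dict saved above can be loaded back into a plain module. A brief sketch:

# scikit-learn-style inference with the fitted wrapper
y_pred = net.predict(X)          # predicted class labels
y_proba = net.predict_proba(X)   # per-class probabilities

# Reload the saved weights into a fresh instance of the same module
model.load_state_dict(torch.load(filename, map_location='cpu'))
model.eval()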
GAN training loop
# TODO: check if it works
def train_conditional_gan_step(G, D, D_opt, G_opt, d_criterion, true_example, condition,
                               step=0,
                               z_dim=128,
                               device='cuda',
                               batch_size=32,
                               n_critic=1,
                               clip_w_discriminator=None, clip_w_generator=None,
                               clip_gradients_discriminator=None, clip_gradients_generator=None):
    # -------------------------------------------
    # Real/fake target labels
    # -------------------------------------------
    D_labels = torch.ones(batch_size, 1).to(device)   # discriminator target for real examples
    D_fakes = torch.zeros(batch_size, 1).to(device)   # discriminator target for fake examples
    # -------------------------------------------
    # DISCRIMINATOR
    # -------------------------------------------
    # Note: this assumes condition is already encoded as a one-hot vector
    x_outputs = D(true_example, condition)
    D_x_loss = d_criterion(x_outputs, D_labels)
    z = torch.randn(batch_size, z_dim).to(device)
    zg = G(z, condition).detach()  # detach so the generator is not updated by the discriminator pass
    z_outputs = D(zg, condition)
    D_z_loss = d_criterion(z_outputs, D_fakes)
    D_loss = D_x_loss + D_z_loss
    D.zero_grad()
    D_loss.backward()
    # Clip gradients of discriminator
    if clip_gradients_discriminator is not None:
        torch.nn.utils.clip_grad_value_(D.parameters(), clip_gradients_discriminator)  # usually: 0.1
    # Clip weights of discriminator (WGAN-style)
    if clip_w_discriminator is not None:
        for p in D.parameters():
            p.data.clamp_(-clip_w_discriminator, +clip_w_discriminator)
    D_opt.step()
    # -------------------------------------------
    # GENERATOR (updated once every n_critic discriminator updates)
    # -------------------------------------------
    gloss = float('nan')  # reported when the generator is not updated this step
    if step % n_critic == 0:
        z = torch.randn(batch_size, z_dim).to(device)
        z_outputs = D(G(z, condition), condition)
        G_loss = d_criterion(z_outputs, D_labels)
        G.zero_grad()
        G_loss.backward()
        # Clip gradients of generator
        if clip_gradients_generator is not None:
            torch.nn.utils.clip_grad_value_(G.parameters(), clip_gradients_generator)  # usually: 0.1
        # Clip weights of generator
        if clip_w_generator is not None:
            for p in G.parameters():
                p.data.clamp_(-clip_w_generator, +clip_w_generator)
        G_opt.step()
        gloss = G_loss.item()
    return D_loss.item(), gloss
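A minimal smoke test for the step function, with tiny throwaway MLPs standing in for the real generator and discriminator (ToyG/ToyD are hypothetical, for illustration only):

import torch

class ToyG(torch.nn.Module):
    def __init__(self, z_dim=128, cond_dim=10, out_dim=64):
        super().__init__()
        self.net = torch.nn.Linear(z_dim + cond_dim, out_dim)
    def forward(self, z, c):
        return self.net(torch.cat([z, c], dim=1))

class ToyD(torch.nn.Module):
    def __init__(self, in_dim=64, cond_dim=10):
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(in_dim + cond_dim, 1), torch.nn.Sigmoid())
    def forward(self, x, c):
        return self.net(torch.cat([x, c], dim=1))

G, D = ToyG(), ToyD()
G_opt = torch.optim.Adam(G.parameters())
D_opt = torch.optim.Adam(D.parameters())
x = torch.randn(32, 64)                          # stand-in "real" batch
c = torch.eye(10)[torch.randint(0, 10, (32,))]   # random one-hot conditions
dloss, gloss = train_conditional_gan_step(G, D, D_opt, G_opt, torch.nn.BCELoss(),
                                          x, c, step=0, device='cpu', batch_size=32)
print(dloss, gloss)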
# TODO: update everything below (for now, just copy and paste)
config = {
    'device': device_name,
    'D_params_clip_value': 0.01,
    'n_critic': 1,
    'batch_size': 256,
    'dim_noise': 128,
    'adam_lr_D': 4e-4,
    'adam_lr_G': 4e-4,
    'adam_beta_1_G': 0.9,
    'adam_beta_2_G': 0.999,
    'adam_beta_1_D': 0.9,
    'adam_beta_2_D': 0.999,
    'max_epoch': 5000,      # training the generator needs far more than 10 epochs
    'condition_size': 10,   # number of condition classes; set this to match your dataset
}
# create G, D, opts
# start recorders
# epoch loop
# - pre epoch
# - batch loop
# - pre batch
# - train step
# - post-batch
# - post-epoch
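The loop below relies on a few helpers defined elsewhere in the notebook. Minimal sketches of what they are assumed to do, with names and behavior inferred from their call sites rather than the original definitions:

import time
import torch

def to_onehot(labels, n_classes=None):
    # Assumed helper: integer class labels -> one-hot float matrix.
    # Width defaults to config['condition_size'] so all batches agree.
    labels = labels.long().view(-1)
    if n_classes is None:
        n_classes = config['condition_size']
    return torch.eye(n_classes)[labels]

def count_parameters(model):
    # Assumed helper: number of trainable parameters
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def time_since(start):
    # Assumed helper: human-readable elapsed time
    s = time.time() - start
    return '{:.0f}m {:.0f}s'.format(s // 60, s % 60)

def get_sample(G, z_dim, cond, how_many=4, as_numpy=True):
    # Assumed helper: draw `how_many` samples from G for the given condition
    device = next(G.parameters()).device
    with torch.no_grad():
        out = G(torch.randn(how_many, z_dim).to(device), cond)
    return out.cpu().numpy() if as_numpy else out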
import datetime
import time
from pathlib import Path

import tqdm
from torch.utils.tensorboard import SummaryWriter

G = Generator(config['dim_noise']).to(device)
D = Discriminator().to(device)
d_criterion = torch.nn.BCELoss()
D_opt = torch.optim.Adam(D.parameters(), lr=config['adam_lr_D'], betas=(config['adam_beta_1_D'], config['adam_beta_2_D']))
G_opt = torch.optim.Adam(G.parameters(), lr=config['adam_lr_G'], betas=(config['adam_beta_1_G'], config['adam_beta_2_G']))
step = 0
start = time.time()
print('[INFO] Started training on {} ...'.format(datetime.datetime.now()))
print('[INFO] Device: {}'.format(device_name))
print('[INFO] Generator params: {:,}'.format(count_parameters(G)))
print('[INFO] Discriminator params: {:,}'.format(count_parameters(D)))
tensorboard_logger = SummaryWriter(flush_secs=10)
tensorboard_logger.add_hparams(hparam_dict=config, metric_dict={})
path_images = Path(tensorboard_logger.get_logdir(), 'generated_sequences_images')
path_images.mkdir()
path_videos = Path(tensorboard_logger.get_logdir(), 'generated_sequences_videos')
path_videos.mkdir()
path_params = Path(tensorboard_logger.get_logdir(), 'networks_params')
path_params.mkdir()
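The two files written just below snapshot the network definitions next to the logs for reproducibility. source_code_G and source_code_D are assumed to be strings captured earlier; one way to obtain them is the standard inspect module (a sketch; inspect.getsource requires the classes to come from an importable source file):

import inspect

# Assumed: capture the class definitions as strings
source_code_G = inspect.getsource(Generator)
source_code_D = inspect.getsource(Discriminator)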
with open(tensorboard_logger.get_logdir() + '/' + "generator.py", "w") as text_file:
text_file.write(source_code_G)
with open(tensorboard_logger.get_logdir() + '/' + "discriminator.py", "w") as text_file:
text_file.write(source_code_D)
save_model = True
print('[INFO] Logdir: {}'.format(tensorboard_logger.get_logdir()))
# tensorboard_logger.add_graph(G, torch.randn(1, config['dim_noise']).to(device))
# tensorboard_logger.add_graph(D, x_train[0].to(device))
for epoch in tqdm.tqdm_notebook(range(step, step + config['max_epoch']), desc='Epoch'):
    if step == 100:
        # Unfreeze the generator's decoder after a warm-up period
        for param in G.decoder.parameters():
            param.requires_grad = True
    for idx, (sequences, labels) in enumerate(tqdm.tqdm_notebook(data_loader, desc='Batch', leave=False)):
        # One combined discriminator/generator training step
        x = sequences.to(device)
        y = to_onehot(labels).to(device)
        dloss, gloss = train_conditional_gan_step(G, D, D_opt, G_opt, d_criterion, x, y,
                                                  step=step,
                                                  z_dim=config['dim_noise'],
                                                  device=config['device'],
                                                  batch_size=config['batch_size'],
                                                  n_critic=config['n_critic'],
                                                  # pass config['D_params_clip_value'] below to enable WGAN-style weight clipping
                                                  clip_w_discriminator=None,
                                                  clip_w_generator=None,
                                                  clip_gradients_discriminator=None,
                                                  clip_gradients_generator=None)
        tensorboard_logger.add_scalar('loss/loss_D', dloss, step)
        tensorboard_logger.add_scalar('loss/loss_G', gloss, step)
        if step % 50 == 0:
            print('Epoch: {}/{} | Step: {} | Time elapsed: {} | D Loss: {}, G Loss: {}'.format(
                epoch, config['max_epoch'], step, time_since(start), dloss, gloss))
        view_every = 10
        if step % view_every == 0:
            G.eval()
            for c in range(config['condition_size']):
                # Build a one-hot condition batch of 4 examples for class c
                cond = torch.zeros(4, 1)
                cond[:, 0] = c
                cond = to_onehot(cond).to(device)
                generated_sequences = get_sample(G, config['dim_noise'], cond, how_many=4, as_numpy=False)
                figure_filename = str(Path(path_images, 'generated_sequences_class_{:02d}_image_step_{:05d}.jpg'.format(c, step)))
                video_filename = str(Path(path_videos, 'generated_sequences_class_{:02d}_video_step_{:05d}.mp4'.format(c, step)))
                plot_four_examples(generated_sequences, save_as=figure_filename, display=False)
                display_animation(block_to_alternated(generated_sequences[0]), display=False, save=True, filename=video_filename,
                                  class_id=c, step=step)
                print(' ...mean=', generated_sequences.mean())
            G.train()
        save_every = 40
        if save_model and step % save_every == 0 and step != 0:
            checkpoint_dir = Path(tensorboard_logger.get_logdir(), 'networks_params', 'step__{:04d}'.format(step))
            checkpoint_dir.mkdir()
            d_filename = str(Path(checkpoint_dir, 'checkpoint_step_{}_notebook_v{}__{}__D.pkl'.format(step, version_notebook, model_name)))
            g_filename = str(Path(checkpoint_dir, 'checkpoint_step_{}_notebook_v{}__{}__G.pkl'.format(step, version_notebook, model_name)))
            d_opt_filename = str(Path(checkpoint_dir, 'checkpoint_step_{}_notebook_v{}__{}__D_opt.pkl'.format(step, version_notebook, model_name)))
            g_opt_filename = str(Path(checkpoint_dir, 'checkpoint_step_{}_notebook_v{}__{}__G_opt.pkl'.format(step, version_notebook, model_name)))
            config_filename = str(Path(checkpoint_dir, 'checkpoint_step_{}_notebook_v{}__{}__config.pkl'.format(step, version_notebook, model_name)))
            torch.save(D.state_dict(), d_filename)
            torch.save(G.state_dict(), g_filename)
            torch.save(D_opt.state_dict(), d_opt_filename)
            torch.save(G_opt.state_dict(), g_opt_filename)
            torch.save(config, config_filename)
            print('[INFO] Saved model in folder {}'.format(checkpoint_dir))
        step += 1
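To resume from one of these checkpoints, load the state dicts back in the same pairing (a sketch using the filenames defined above):

# Restore networks and optimizers from the saved checkpoint files
D.load_state_dict(torch.load(d_filename, map_location=device))
G.load_state_dict(torch.load(g_filename, map_location=device))
D_opt.load_state_dict(torch.load(d_opt_filename))
G_opt.load_state_dict(torch.load(g_opt_filename))
config = torch.load(config_filename)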